├── README.md
├── assets
│   ├── OpenSans-Regular.ttf
│   ├── character_sheet_default_ref.jpg
│   ├── characters_parts
│   │   ├── part_a.jpg
│   │   ├── part_b.jpg
│   │   └── part_c.jpg
│   ├── openimages_classes.txt
│   ├── plush_parts
│   │   ├── part_a.jpg
│   │   ├── part_b.jpg
│   │   └── part_c.jpg
│   └── product_parts
│       ├── part_a.jpg
│       ├── part_b.jpg
│       └── part_c.jpg
├── configs
│   ├── infer
│   │   ├── infer_characters.yaml
│   │   ├── infer_plush.yaml
│   │   └── infer_products.yaml
│   └── train
│       └── train_characters.yaml
├── demo
│   ├── app.py
│   └── pit.py
├── ip_adapter
│   ├── __init__.py
│   ├── attention_processor.py
│   ├── attention_processor_faceid.py
│   ├── custom_pipelines.py
│   ├── ip_adapter.py
│   ├── ip_adapter_faceid.py
│   ├── ip_adapter_faceid_separate.py
│   ├── resampler.py
│   ├── sd3_attention_processor.py
│   ├── test_resampler.py
│   └── utils.py
├── ip_lora_inference
│   ├── download_ip_adapter.sh
│   ├── download_loras.sh
│   └── inference_ip_lora.py
├── ip_lora_train
│   ├── ip_adapter_for_lora.py
│   ├── ip_lora_dataset.py
│   ├── run_example.sh
│   ├── sdxl_ip_lora_pipeline.py
│   └── train_ip_lora.py
├── ip_plus_space_exploration
│   ├── download_directions.sh
│   ├── download_ip_adapter.sh
│   ├── edit_by_direction.py
│   ├── find_direction.py
│   └── ip_model_utils.py
├── model
│   ├── __init__.py
│   ├── dit.py
│   └── pipeline_pit.py
├── pyproject.toml
├── scripts
│   ├── generate_characters.py
│   ├── generate_products.py
│   ├── infer.py
│   └── train.py
├── training
│   ├── __init__.py
│   ├── coach.py
│   ├── dataset.py
│   └── train_config.py
└── utils
    ├── __init__.py
    ├── bezier_utils.py
    ├── vis_utils.py
    └── words_bank.py
/README.md:
--------------------------------------------------------------------------------
1 | # Piece it Together: Part-Based Concepting with IP-Priors
2 | > Elad Richardson, Kfir Goldberg, Yuval Alaluf, Daniel Cohen-Or
3 | > Tel Aviv University, Bria AI
4 | >
5 | > Advanced generative models excel at synthesizing images but often rely on text-based conditioning. Visual designers, however, often work beyond language, directly drawing inspiration from existing visual elements. In many cases, these elements represent only fragments of a potential concept, such as a uniquely structured wing or a specific hairstyle, serving as inspiration for the artist to explore how they can come together creatively into a coherent whole. Recognizing this need, we introduce a generative framework that seamlessly integrates a partial set of user-provided visual components into a coherent composition while simultaneously sampling the missing parts needed to generate a plausible and complete concept. Our approach builds on a strong and underexplored representation space, extracted from IP-Adapter+, on which we train IP-Prior, a lightweight flow-matching model that synthesizes coherent compositions based on domain-specific priors, enabling diverse and context-aware generations. Additionally, we present a LoRA-based fine-tuning strategy that significantly improves prompt adherence in IP-Adapter+ for a given task, addressing its common trade-off between reconstruction quality and prompt adherence.
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | Using a dedicated prior for the target domain, our method, Piece it Together (PiT), effectively completes missing information by seamlessly integrating given elements into a coherent composition while adding the missing pieces needed for the complete concept to reside in the prior domain.
17 |
18 |
19 | ## Description :scroll:
20 | Official implementation of the paper "Piece it Together: Part-Based Concepting with IP-Priors"
21 |
22 |
23 | ## Table of contents
24 | - [Piece it Together: Part-Based Concepting with IP-Priors](#piece-it-together-part-based-concepting-with-ip-priors)
25 | - [Description :scroll:](#description-scroll)
26 | - [Table of contents](#table-of-contents)
27 | - [Getting started with PiT :rocket:](#getting-started-with-pit-rocket)
28 | - [Setup your environment](#setup-your-environment)
29 | - [Inference with PiT](#inference-with-pit)
30 | - [Training PiT](#training-pit)
31 | - [Inference with IP-LoRA](#inference-with-ip-lora)
32 | - [Training IP-LoRA](#training-ip-lora)
33 | - [Preparing your data](#preparing-your-data)
34 | - [Running the training script](#running-the-training-script)
35 | - [Exploring the IP+ space](#exploring-the-ip-space)
36 | - [Finding new directions](#finding-new-directions)
37 | - [Editing images with found directions](#editing-images-with-found-directions)
38 | - [Acknowledgments](#acknowledgments)
39 | - [Citation](#citation)
40 |
41 |
42 |
43 | ## Getting started with PiT :rocket:
44 |
45 | ### Setup your environment
46 |
47 | 1. Clone the repo:
48 |
49 | ```bash
50 | git clone https://github.com/eladrich/PiT
51 | cd PiT
52 | ```
53 |
54 | 2. Install `uv`:
55 |
56 | Instructions taken from [here](https://docs.astral.sh/uv/getting-started/installation/).
57 |
58 | For linux systems this should be:
59 | ```bash
60 | curl -LsSf https://astral.sh/uv/install.sh | sh
61 | source $HOME/.local/bin/env
62 | ```
63 |
64 | 3. Install the dependencies:
65 |
66 | ```bash
67 | uv sync
68 | ```
69 |
70 | 4. Activate your `.venv` and set the Python env:
71 |
72 | ```bash
73 | source .venv/bin/activate
74 | export PYTHONPATH=${PYTHONPATH}:${PWD}
75 | ```
76 |
77 |
78 |
79 | ## Inference with PiT
80 | | Domain | Examples | Link |
81 | |--------|--------------|----------------------------------------------------------------------------------------------|
82 | | Characters | | [Here](https://huggingface.co/kfirgold99/Piece-it-Together/tree/main/models/characters_ckpt) |
83 | | Products | | [Here](https://huggingface.co/kfirgold99/Piece-it-Together/tree/main/models/products_ckpt) |
84 | | Toys | | [Here](https://huggingface.co/kfirgold99/Piece-it-Together/tree/main/models/plush_ckpt) |
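
The table links to the checkpoints on HuggingFace; they can also be fetched programmatically with `huggingface_hub`. Below is a minimal sketch mirroring the download performed in `demo/pit.py`, using the characters checkpoint as an example:

```python
from huggingface_hub import hf_hub_download

# Fetch the characters prior and its training config (same repo/paths as demo/pit.py).
ckpt_path = hf_hub_download(
    repo_id="kfirgold99/Piece-it-Together",
    filename="models/characters_ckpt/prior.ckpt",
    local_dir="pretrained_models",
)
cfg_path = hf_hub_download(
    repo_id="kfirgold99/Piece-it-Together",
    filename="models/characters_ckpt/cfg.yaml",
    local_dir="pretrained_models",
)
```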
85 |
86 |
87 | ## Training PiT
88 |
89 | ### Data Generation
90 | PiT assumes that the data is structured so that the target images and part images are in the same directory, with the naming convention being `image_name.jpg` for the base image and `image_name_i.jpg` for the parts.
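
For example, a generated characters dataset (matching the `dataset_path` in `configs/train/train_characters.yaml`) might look like the listing below; the file stems and part indices are illustrative:

```
datasets/generated/characters/
├── creature_001.jpg      # base (target) image
├── creature_001_0.jpg    # part crop
├── creature_001_1.jpg    # part crop
├── creature_001_2.jpg    # part crop
├── creature_002.jpg
├── creature_002_0.jpg
...
```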
91 |
92 | To generate the data, see the sample scripts:
93 | ```bash
94 | python -m scripts.generate_characters
95 | ```
96 |
97 | ```bash
98 | python -m scripts.generate_products
99 | ```
100 |
101 | ### Training
102 |
103 | For training, see the `training/coach.py` file and the example below:
104 |
105 | ```bash
106 | python -m scripts.train --config_path=configs/train/train_characters.yaml
107 | ```
108 |
109 | ## PiT Inference
110 |
111 | For inference, see `scripts/infer.py` with the corresponding configs under `configs/infer`:
112 |
113 | ```bash
114 | python -m scripts.infer --config_path=configs/infer/infer_characters.yaml
115 | ```
116 |
117 |
118 | ## Inference with IP-LoRA
119 |
120 | 1. Download the IP-Adapter checkpoint and the LoRAs:
121 |
122 | ```bash
123 | ip_lora_inference/download_ip_adapter.sh
124 | ip_lora_inference/download_loras.sh
125 | ```
126 |
127 | 2. Run inference with your preferred model
128 |
129 | For example, to run the styled-generation LoRA:
130 |
131 | ```bash
132 | python ip_lora_inference/inference_ip_lora.py --lora_type "character_sheet" --lora_path "weights/character_sheet/pytorch_lora_weights.safetensors" --prompt "a character sheet displaying a creature, from several angles with 1 large front view in the middle, clean white background. In the background we can see half-completed, partially colored, sketches of different parts of the object" --output_dir "ip_lora_inference/character_sheet/" --ref_images_paths "assets/character_sheet_default_ref.jpg" \
133 | --ip_adapter_path "weights/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
134 | ```
135 |
136 | ## Training IP-LoRA
137 |
138 | ### Preparing your data
139 |
140 | The expected data format for the training script is as follows:
141 |
142 | ```
143 | --base_dir/
144 | ----targets/
145 | ------img1.jpg
146 | ------img1.txt
147 | ------img2.jpg
148 | ------img2.txt
149 | ------img3.jpg
150 | ------img3.txt
151 | .
152 | .
153 | .
154 | ----refs/
155 | ------img1_ref.jpg
156 | ------img2_ref.jpg
157 | ------img3_ref.jpg
158 | .
159 | .
160 | .
161 | ```
162 |
163 | Here, `imgX.jpg` is the target image for the input reference image `imgX_ref.jpg`, and `imgX.txt` contains its prompt.
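
The actual dataset logic lives in `ip_lora_train/ip_lora_dataset.py`; the sketch below only illustrates how the files pair up under this layout (`base_dir` is a placeholder path):

```python
from pathlib import Path

# Illustrative only: collect (target, reference, prompt) triplets from the
# layout above. The real dataloader is ip_lora_train/ip_lora_dataset.py.
base_dir = Path("/path/to/base_dir")  # placeholder
triplets = []
for target in sorted((base_dir / "targets").glob("*.jpg")):
    ref = base_dir / "refs" / f"{target.stem}_ref.jpg"
    prompt = target.with_suffix(".txt").read_text().strip()
    triplets.append((target, ref, prompt))
```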
164 |
165 | ### Running the training script
166 |
167 | For training a character-sheet styled generation LoRA, run the following command:
168 |
169 | ```bash
170 | python ./ip_lora_train/train_ip_lora.py \
171 | --rank 64 \
172 | --resolution 1024 \
173 | --validation_epochs 1 \
174 | --num_train_epochs 100 \
175 | --checkpointing_steps 50 \
176 | --train_batch_size 2 \
177 | --learning_rate 1e-4 \
178 | --dataloader_num_workers 1 \
179 | --gradient_accumulation_steps 8 \
180 | --dataset_base_dir \
181 | --prompt_mode character_sheet \
182 | --output_dir ./output/train_ip_lora/character_sheet
183 |
184 | ```
185 |
186 | and for the text adherence LoRA, run the following command:
187 |
188 | ```bash
189 | python ./ip_lora_train/train_ip_lora.py \
190 | --rank 64 \
191 | --resolution 1024 \
192 | --validation_epochs 1 \
193 | --num_train_epochs 100 \
194 | --checkpointing_steps 50 \
195 | --train_batch_size 2 \
196 | --learning_rate 1e-4 \
197 | --dataloader_num_workers 1 \
198 | --gradient_accumulation_steps 8 \
199 | --dataset_base_dir \
200 | --prompt_mode creature_in_scene \
201 | --output_dir ./output/train_ip_lora/creature_in_scene
202 | ```
203 |
204 | ## Exploring the IP+ space
205 |
206 | Start by downloading the needed IP+ checkpoint and the directions presented in the paper:
207 |
208 | ```bash
209 | ip_plus_space_exploration/download_directions.sh
210 | ip_plus_space_exploration/download_ip_adapter.sh
211 | ```
212 |
213 | ### Finding new directions
214 |
215 | To find a direction in the IP+ space from "class1" (e.g. "scrawny") to "class2" (e.g. "muscular"):
216 |
217 | 1. Create `class1_dir` and `class2_dir` containing images of the source and target classes, respectively.
218 |
219 | 2. Run the `find_direction` script:
220 |
221 | ```bash
222 | python ip_plus_space_exploration/find_direction.py --class1_dir --class2_dir --output_dir ./ip_directions --ip_model_type "plus"
223 | ```
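
Conceptually, a direction is a vector in the IP+ embedding space pointing from class1 toward class2. The sketch below illustrates that idea only; it is not the exact implementation (see `find_direction.py`), and the `get_image_embeds` usage follows `demo/pit.py`:

```python
from pathlib import Path

import torch
from PIL import Image


def mean_ip_embedding(ip_model, image_dir: str, device: str = "cuda") -> torch.Tensor:
    """Average the IP+ embeddings of all images in a directory (sketch)."""
    processor = ip_model.clip_image_processor
    embeds = []
    for path in sorted(Path(image_dir).glob("*.jpg")):
        image = Image.open(path).convert("RGB")
        pixels = torch.Tensor(processor(image)["pixel_values"][0]).unsqueeze(0).to(device)
        embeds.append(ip_model.get_image_embeds(pixels, skip_uncond=True))
    return torch.cat(embeds, dim=0).mean(dim=0, keepdim=True)


# Direction from class1 (e.g. "scrawny") to class2 (e.g. "muscular"):
# direction = mean_ip_embedding(ip_model, class2_dir) - mean_ip_embedding(ip_model, class1_dir)
```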
224 |
225 | ### Editing images with found directions
226 |
227 | Use a direction found in the previous stage, or one downloaded from [HuggingFace](https://huggingface.co/kfirgold99/Piece-it-Together):
228 |
229 | ```bash
230 | python ip_plus_space_exploration/edit_by_direction.py --ip_model_type "plus" --image_path --direction_path --direction_type "ip" --output_dir "./edit_by_direction/"
231 | ```
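
Under the hood, the edit amounts to shifting the image's IP+ embedding along the chosen direction and decoding the result with IP-Adapter+. A rough sketch of that step (the strength `alpha` is a hypothetical knob here; the decode call mirrors `demo/pit.py`, and `edit_by_direction.py` exposes the real options):

```python
import torch


def apply_direction(image_embed: torch.Tensor, direction: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Shift an IP+ image embedding along a found direction (illustrative)."""
    return image_embed + alpha * direction


# edited = apply_direction(img_emb, direction, alpha=1.0)
# images = ip_model.generate(image_prompt_embeds=edited, num_samples=1, num_inference_steps=50)
```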
232 |
233 | ## Acknowledgments
234 |
235 | Code is based on
236 | - https://github.com/pOpsPaper/pOps
237 | - https://github.com/cloneofsimo/minRF by the great [@cloneofsimo](https://github.com/cloneofsimo)
238 |
239 | ## Citation
240 |
241 | If you use this code for your research, please cite the following paper:
242 |
243 | ```
244 | @misc{richardson2025piece,
245 | title={Piece it Together: Part-Based Concepting with IP-Priors},
246 | author={Richardson, Elad and Goldberg, Kfir and Alaluf, Yuval and Cohen-Or, Daniel},
247 | year={2025},
248 | eprint={2503.10365},
249 | archivePrefix={arXiv},
250 | primaryClass={cs.CV},
251 | url={https://arxiv.org/abs/2503.10365},
252 | }
253 | ```
--------------------------------------------------------------------------------
/assets/OpenSans-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/OpenSans-Regular.ttf
--------------------------------------------------------------------------------
/assets/character_sheet_default_ref.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/character_sheet_default_ref.jpg
--------------------------------------------------------------------------------
/assets/characters_parts/part_a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/characters_parts/part_a.jpg
--------------------------------------------------------------------------------
/assets/characters_parts/part_b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/characters_parts/part_b.jpg
--------------------------------------------------------------------------------
/assets/characters_parts/part_c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/characters_parts/part_c.jpg
--------------------------------------------------------------------------------
/assets/plush_parts/part_a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/plush_parts/part_a.jpg
--------------------------------------------------------------------------------
/assets/plush_parts/part_b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/plush_parts/part_b.jpg
--------------------------------------------------------------------------------
/assets/plush_parts/part_c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/plush_parts/part_c.jpg
--------------------------------------------------------------------------------
/assets/product_parts/part_a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/product_parts/part_a.jpg
--------------------------------------------------------------------------------
/assets/product_parts/part_b.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/product_parts/part_b.jpg
--------------------------------------------------------------------------------
/assets/product_parts/part_c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/assets/product_parts/part_c.jpg
--------------------------------------------------------------------------------
/configs/infer/infer_characters.yaml:
--------------------------------------------------------------------------------
1 | prior_path: models/characters_ckpt/prior.ckpt
2 | prior_repo: kfirgold99/Piece-it-Together
3 | crops_dir: assets/characters_parts
4 | output_dir: inference/characters
--------------------------------------------------------------------------------
/configs/infer/infer_plush.yaml:
--------------------------------------------------------------------------------
1 | prior_path: models/plush_ckpt/prior.ckpt
2 | prior_repo: kfirgold99/Piece-it-Together
3 | crops_dir: assets/plush_parts
4 | output_dir: inference/plush
--------------------------------------------------------------------------------
/configs/infer/infer_products.yaml:
--------------------------------------------------------------------------------
1 | prior_path: models/products_ckpt/prior.ckpt
2 | prior_repo: kfirgold99/Piece-it-Together
3 | crops_dir: assets/product_parts
4 | output_dir: inference/products
--------------------------------------------------------------------------------
/configs/train/train_characters.yaml:
--------------------------------------------------------------------------------
1 | dataset_path: 'datasets/generated/characters'
2 | val_dataset_path: 'datasets/generated/characters'
3 | output_dir: 'training_results/train_characters'
4 | train_batch_size: 64
5 | num_layers: 4
6 | max_crops: 7
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/demo/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import spaces
3 | from pit import PiTDemoPipeline
4 |
5 | BLOCK_WIDTH = 300
6 | BLOCK_HEIGHT = 360
7 | FONT_SIZE = 3.5
8 |
9 | pit_pipeline = PiTDemoPipeline(
10 | prior_repo="kfirgold99/Piece-it-Together", prior_path="models/characters_ckpt/prior.ckpt"
11 | )
12 |
13 |
14 | @spaces.GPU
15 | def run_character_generation(part_1, part_2, part_3, seed=None):
16 | crops_paths = [part_1, part_2, part_3]
17 | image = pit_pipeline.run(crops_paths=crops_paths, seed=seed, n_images=1)[0]
18 | return image
19 |
20 |
21 | with gr.Blocks(css="style.css") as demo:
22 | gr.HTML(
23 | """Piece it Together: Part-Based Concepting with IP-Priors
"""
24 | )
25 | gr.HTML(
26 | ''
27 | )
28 | gr.HTML(
29 | 'Piece it Together (PiT) combines different input parts to generate a complete concept in a prior domain.'
30 | )
31 | with gr.Row(equal_height=True, elem_classes="justified-element"):
32 | with gr.Column(scale=0, min_width=BLOCK_WIDTH):
33 | part_1 = gr.Image(label="Upload part 1 (or keep empty)", type="filepath", width=BLOCK_WIDTH, height=BLOCK_HEIGHT)
34 | with gr.Column(scale=0, min_width=BLOCK_WIDTH):
35 | part_2 = gr.Image(label="Upload part 2 (or keep empty)", type="filepath", width=BLOCK_WIDTH, height=BLOCK_HEIGHT)
36 | with gr.Column(scale=0, min_width=BLOCK_WIDTH):
37 | part_3 = gr.Image(label="Upload part 3 (or keep empty)", type="filepath", width=BLOCK_WIDTH, height=BLOCK_HEIGHT)
38 | with gr.Column(scale=0, min_width=BLOCK_WIDTH):
39 | output_eq_1 = gr.Image(label="Output", width=BLOCK_WIDTH, height=BLOCK_HEIGHT)
40 | with gr.Row(equal_height=True, elem_classes="justified-element"):
41 | run_button = gr.Button("Create your character!", elem_classes="small-elem")
42 | run_button.click(fn=run_character_generation, inputs=[part_1, part_2, part_3], outputs=[output_eq_1])
43 | with gr.Row(equal_height=True, elem_classes="justified-element"):
44 | pass
45 |
46 | with gr.Row(equal_height=True, elem_classes="justified-element"):
47 | with gr.Column(scale=1):
48 | examples = [
49 | [
50 | "assets/characters_parts/part_a.jpg",
51 | "assets/characters_parts/part_b.jpg",
52 | "assets/characters_parts/part_c.jpg",
53 | ]
54 | ]
55 | gr.Examples(
56 | examples=examples,
57 | inputs=[part_1, part_2, part_3],
58 | outputs=[output_eq_1],
59 | fn=run_character_generation,
60 | cache_examples=False,
61 | )
62 |
63 | demo.queue().launch(share=True)
64 |
--------------------------------------------------------------------------------
/demo/pit.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import pyrallis
7 | import torch
8 | from diffusers import (
9 | StableDiffusionXLPipeline,
10 | )
11 | from huggingface_hub import hf_hub_download
12 | from PIL import Image
13 |
14 | from ip_adapter import IPAdapterPlusXL
15 | from model.dit import DiT_Llama
16 | from model.pipeline_pit import PiTPipeline
17 | from training.train_config import TrainConfig
18 |
19 |
20 | def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None):
21 | # Calculate aspect ratio and determine resizing based on the smaller dimension of the background
22 | aspect_ratio = image.width / image.height
23 | scale = random.uniform(min_scale, max_scale) if scale is None else scale
24 | new_width = int(min(background.width, background.height * aspect_ratio) * scale)
25 | new_height = int(new_width / aspect_ratio)
26 |
27 | # Resize image and calculate position
28 | image = image.resize((new_width, new_height), resample=Image.LANCZOS)
29 | pos_x = random.randint(0, background.width - new_width)
30 | pos_y = random.randint(0, background.height - new_height)
31 |
32 | # Paste the image using its alpha channel as mask if present
33 | background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
34 | return background
35 |
36 |
37 | def set_seed(seed: int):
38 | """Ensures reproducibility across multiple libraries."""
39 | random.seed(seed) # Python random module
40 | np.random.seed(seed) # NumPy random module
41 | torch.manual_seed(seed) # PyTorch CPU random seed
42 | torch.cuda.manual_seed_all(seed) # PyTorch GPU random seed
43 | torch.backends.cudnn.deterministic = True # Ensures deterministic behavior
44 | torch.backends.cudnn.benchmark = False # Disable benchmarking to avoid randomness
45 |
46 |
47 | class PiTDemoPipeline:
48 | def __init__(self, prior_repo: str, prior_path: str):
49 | # Download model and config
50 | prior_ckpt_path = hf_hub_download(
51 | repo_id=prior_repo,
52 | filename=str(prior_path),
53 | local_dir="pretrained_models",
54 | )
55 | prior_cfg_path = hf_hub_download(
56 | repo_id=prior_repo, filename=str(Path(prior_path).parent / "cfg.yaml"), local_dir="pretrained_models"
57 | )
58 | self.model_cfg: TrainConfig = pyrallis.load(TrainConfig, open(prior_cfg_path, "r"))
59 |
60 | self.weight_dtype = torch.float32
61 | self.device = "cuda:0"
62 | prior = DiT_Llama(
63 | embedding_dim=2048,
64 | hidden_dim=self.model_cfg.hidden_dim,
65 | n_layers=self.model_cfg.num_layers,
66 | n_heads=self.model_cfg.num_attention_heads,
67 | )
68 | prior.load_state_dict(torch.load(prior_ckpt_path))
69 | image_pipe = StableDiffusionXLPipeline.from_pretrained(
70 | "stabilityai/stable-diffusion-xl-base-1.0",
71 | torch_dtype=torch.float16,
72 | add_watermarker=False,
73 | )
74 | ip_ckpt_path = hf_hub_download(
75 | repo_id="h94/IP-Adapter",
76 | filename="ip-adapter-plus_sdxl_vit-h.bin",
77 | subfolder="sdxl_models",
78 | local_dir="pretrained_models",
79 | )
80 |
81 | self.ip_model = IPAdapterPlusXL(
82 | image_pipe,
83 | "models/image_encoder",
84 | ip_ckpt_path,
85 | self.device,
86 | num_tokens=16,
87 | )
88 | self.image_processor = self.ip_model.clip_image_processor
89 |
90 | empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
91 | zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0])
92 | self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)
93 |
94 | prior_pipeline = PiTPipeline(
95 | prior=prior,
96 | )
97 | self.prior_pipeline = prior_pipeline.to(self.device)
98 | set_seed(42)
99 |
100 | def run(self, crops_paths: list[str], scale: float = 2.0, seed: Optional[int] = None, n_images: int = 1):
101 | if seed is not None:
102 | set_seed(seed)
103 | processed_crops = []
104 | input_images = []
105 |
106 | crops_paths = [None] + crops_paths
107 | # Pad with Nones so there are at least 3 part slots
108 | while len(crops_paths) < 3:
109 | crops_paths.append(None)
110 |
111 | for path_ind, path in enumerate(crops_paths):
112 | if path is None:
113 | image = Image.new("RGB", (224, 224), (255, 255, 255))
114 | else:
115 | image = Image.open(path).convert("RGB")
116 | if path_ind > 0 or not self.model_cfg.use_ref:
117 | background = Image.new("RGB", (1024, 1024), (255, 255, 255))
118 | image = paste_on_background(image, background, scale=0.92)
119 | else:
120 | image = image.resize((1024, 1024))
121 | input_images.append(image)
122 | # Name should be parent directory name
123 | processed_image = (
124 | torch.Tensor(self.image_processor(image)["pixel_values"][0])
125 | .to(self.device)
126 | .unsqueeze(0)
127 | .to(self.weight_dtype)
128 | )
129 | processed_crops.append(processed_image)
130 |
131 | image_embed_inputs = []
132 | for crop_ind in range(len(processed_crops)):
133 | image_embed_inputs.append(self.ip_model.get_image_embeds(processed_crops[crop_ind], skip_uncond=True))
134 | crops_input_sequence = torch.cat(image_embed_inputs, dim=1)
135 | generated_images = []
136 | for _ in range(n_images):
137 | seed = random.randint(0, 1000000)
138 | for curr_scale in [scale]:
139 | negative_cond_sequence = torch.zeros_like(crops_input_sequence)
140 | embeds_len = self.zero_image_embeds.shape[1]
141 | for i in range(0, negative_cond_sequence.shape[1], embeds_len):
142 | negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach()
143 |
144 | img_emb = self.prior_pipeline(
145 | cond_sequence=crops_input_sequence,
146 | negative_cond_sequence=negative_cond_sequence,
147 | num_inference_steps=25,
148 | num_images_per_prompt=1,
149 | guidance_scale=curr_scale,
150 | generator=torch.Generator(device="cuda").manual_seed(seed),
151 | ).image_embeds
152 |
153 | for seed_2 in range(1):
154 | images = self.ip_model.generate(
155 | image_prompt_embeds=img_emb,
156 | num_samples=1,
157 | num_inference_steps=50,
158 | )
159 | generated_images += images
160 |
161 | return generated_images
162 |
--------------------------------------------------------------------------------
/ip_adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2 |
3 | __all__ = [
4 | "IPAdapter",
5 | "IPAdapterPlus",
6 | "IPAdapterPlusXL",
7 | "IPAdapterXL",
8 | "IPAdapterFull",
9 | ]
10 |
--------------------------------------------------------------------------------
/ip_adapter/attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class AttnProcessor(nn.Module):
8 | r"""
9 | Default processor for performing attention-related computations.
10 | """
11 |
12 | def __init__(
13 | self,
14 | hidden_size=None,
15 | cross_attention_dim=None,
16 | ):
17 | super().__init__()
18 |
19 | def __call__(
20 | self,
21 | attn,
22 | hidden_states,
23 | encoder_hidden_states=None,
24 | attention_mask=None,
25 | temb=None,
26 | *args,
27 | **kwargs,
28 | ):
29 | residual = hidden_states
30 |
31 | if attn.spatial_norm is not None:
32 | hidden_states = attn.spatial_norm(hidden_states, temb)
33 |
34 | input_ndim = hidden_states.ndim
35 |
36 | if input_ndim == 4:
37 | batch_size, channel, height, width = hidden_states.shape
38 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
39 |
40 | batch_size, sequence_length, _ = (
41 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
42 | )
43 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
44 |
45 | if attn.group_norm is not None:
46 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
47 |
48 | query = attn.to_q(hidden_states)
49 |
50 | if encoder_hidden_states is None:
51 | encoder_hidden_states = hidden_states
52 | elif attn.norm_cross:
53 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
54 |
55 | key = attn.to_k(encoder_hidden_states)
56 | value = attn.to_v(encoder_hidden_states)
57 |
58 | query = attn.head_to_batch_dim(query)
59 | key = attn.head_to_batch_dim(key)
60 | value = attn.head_to_batch_dim(value)
61 |
62 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
63 | hidden_states = torch.bmm(attention_probs, value)
64 | hidden_states = attn.batch_to_head_dim(hidden_states)
65 |
66 | # linear proj
67 | hidden_states = attn.to_out[0](hidden_states)
68 | # dropout
69 | hidden_states = attn.to_out[1](hidden_states)
70 |
71 | if input_ndim == 4:
72 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
73 |
74 | if attn.residual_connection:
75 | hidden_states = hidden_states + residual
76 |
77 | hidden_states = hidden_states / attn.rescale_output_factor
78 |
79 | return hidden_states
80 |
81 |
82 | class IPAttnProcessor(nn.Module):
83 | r"""
84 | Attention processor for IP-Adapter.
85 | Args:
86 | hidden_size (`int`):
87 | The hidden size of the attention layer.
88 | cross_attention_dim (`int`):
89 | The number of channels in the `encoder_hidden_states`.
90 | scale (`float`, defaults to 1.0):
91 | the weight scale of image prompt.
92 | num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
93 | The context length of the image features.
94 | """
95 |
96 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
97 | super().__init__()
98 |
99 | self.hidden_size = hidden_size
100 | self.cross_attention_dim = cross_attention_dim
101 | self.scale = scale
102 | self.num_tokens = num_tokens
103 |
104 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
106 |
107 | def __call__(
108 | self,
109 | attn,
110 | hidden_states,
111 | encoder_hidden_states=None,
112 | attention_mask=None,
113 | temb=None,
114 | *args,
115 | **kwargs,
116 | ):
117 | residual = hidden_states
118 |
119 | if attn.spatial_norm is not None:
120 | hidden_states = attn.spatial_norm(hidden_states, temb)
121 |
122 | input_ndim = hidden_states.ndim
123 |
124 | if input_ndim == 4:
125 | batch_size, channel, height, width = hidden_states.shape
126 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
127 |
128 | batch_size, sequence_length, _ = (
129 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
130 | )
131 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
132 |
133 | if attn.group_norm is not None:
134 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
135 |
136 | query = attn.to_q(hidden_states)
137 |
138 | if encoder_hidden_states is None:
139 | encoder_hidden_states = hidden_states
140 | else:
141 | # get encoder_hidden_states, ip_hidden_states
142 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
143 | encoder_hidden_states, ip_hidden_states = (
144 | encoder_hidden_states[:, :end_pos, :],
145 | encoder_hidden_states[:, end_pos:, :],
146 | )
147 | if attn.norm_cross:
148 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
149 |
150 | key = attn.to_k(encoder_hidden_states)
151 | value = attn.to_v(encoder_hidden_states)
152 |
153 | query = attn.head_to_batch_dim(query)
154 | key = attn.head_to_batch_dim(key)
155 | value = attn.head_to_batch_dim(value)
156 |
157 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
158 | hidden_states = torch.bmm(attention_probs, value)
159 | hidden_states = attn.batch_to_head_dim(hidden_states)
160 |
161 | # for ip-adapter
162 | ip_key = self.to_k_ip(ip_hidden_states)
163 | ip_value = self.to_v_ip(ip_hidden_states)
164 |
165 | ip_key = attn.head_to_batch_dim(ip_key)
166 | ip_value = attn.head_to_batch_dim(ip_value)
167 |
168 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
169 | self.attn_map = ip_attention_probs
170 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
171 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
172 |
173 | hidden_states = hidden_states + self.scale * ip_hidden_states
174 |
175 | # linear proj
176 | hidden_states = attn.to_out[0](hidden_states)
177 | # dropout
178 | hidden_states = attn.to_out[1](hidden_states)
179 |
180 | if input_ndim == 4:
181 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
182 |
183 | if attn.residual_connection:
184 | hidden_states = hidden_states + residual
185 |
186 | hidden_states = hidden_states / attn.rescale_output_factor
187 |
188 | return hidden_states
189 |
190 |
191 | class AttnProcessor2_0(torch.nn.Module):
192 | r"""
193 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
194 | """
195 |
196 | def __init__(
197 | self,
198 | hidden_size=None,
199 | cross_attention_dim=None,
200 | ):
201 | super().__init__()
202 | if not hasattr(F, "scaled_dot_product_attention"):
203 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
204 |
205 | def __call__(
206 | self,
207 | attn,
208 | hidden_states,
209 | encoder_hidden_states=None,
210 | attention_mask=None,
211 | temb=None,
212 | *args,
213 | **kwargs,
214 | ):
215 | residual = hidden_states
216 |
217 | if attn.spatial_norm is not None:
218 | hidden_states = attn.spatial_norm(hidden_states, temb)
219 |
220 | input_ndim = hidden_states.ndim
221 |
222 | if input_ndim == 4:
223 | batch_size, channel, height, width = hidden_states.shape
224 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
225 |
226 | batch_size, sequence_length, _ = (
227 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
228 | )
229 |
230 | if attention_mask is not None:
231 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
232 | # scaled_dot_product_attention expects attention_mask shape to be
233 | # (batch, heads, source_length, target_length)
234 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
235 |
236 | if attn.group_norm is not None:
237 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
238 |
239 | query = attn.to_q(hidden_states)
240 |
241 | if encoder_hidden_states is None:
242 | encoder_hidden_states = hidden_states
243 | elif attn.norm_cross:
244 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
245 |
246 | key = attn.to_k(encoder_hidden_states)
247 | value = attn.to_v(encoder_hidden_states)
248 |
249 | inner_dim = key.shape[-1]
250 | head_dim = inner_dim // attn.heads
251 |
252 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
253 |
254 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
255 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
256 |
257 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
258 | # TODO: add support for attn.scale when we move to Torch 2.1
259 | hidden_states = F.scaled_dot_product_attention(
260 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
261 | )
262 |
263 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
264 | hidden_states = hidden_states.to(query.dtype)
265 |
266 | # linear proj
267 | hidden_states = attn.to_out[0](hidden_states)
268 | # dropout
269 | hidden_states = attn.to_out[1](hidden_states)
270 |
271 | if input_ndim == 4:
272 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
273 |
274 | if attn.residual_connection:
275 | hidden_states = hidden_states + residual
276 |
277 | hidden_states = hidden_states / attn.rescale_output_factor
278 |
279 | return hidden_states
280 |
281 |
282 | class IPAttnProcessor2_0(torch.nn.Module):
283 | r"""
284 | Attention processor for IP-Adapter for PyTorch 2.0.
285 | Args:
286 | hidden_size (`int`):
287 | The hidden size of the attention layer.
288 | cross_attention_dim (`int`):
289 | The number of channels in the `encoder_hidden_states`.
290 | scale (`float`, defaults to 1.0):
291 | the weight scale of image prompt.
292 | num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
293 | The context length of the image features.
294 | """
295 |
296 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
297 | super().__init__()
298 |
299 | if not hasattr(F, "scaled_dot_product_attention"):
300 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
301 |
302 | self.hidden_size = hidden_size
303 | self.cross_attention_dim = cross_attention_dim
304 | self.scale = scale
305 | self.num_tokens = num_tokens
306 |
307 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
308 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
309 |
310 | def __call__(
311 | self,
312 | attn,
313 | hidden_states,
314 | encoder_hidden_states=None,
315 | attention_mask=None,
316 | temb=None,
317 | *args,
318 | **kwargs,
319 | ):
320 | residual = hidden_states
321 |
322 | if attn.spatial_norm is not None:
323 | hidden_states = attn.spatial_norm(hidden_states, temb)
324 |
325 | input_ndim = hidden_states.ndim
326 |
327 | if input_ndim == 4:
328 | batch_size, channel, height, width = hidden_states.shape
329 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
330 |
331 | batch_size, sequence_length, _ = (
332 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
333 | )
334 |
335 | if attention_mask is not None:
336 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
337 | # scaled_dot_product_attention expects attention_mask shape to be
338 | # (batch, heads, source_length, target_length)
339 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
340 |
341 | if attn.group_norm is not None:
342 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
343 |
344 | query = attn.to_q(hidden_states)
345 |
346 | if encoder_hidden_states is None:
347 | encoder_hidden_states = hidden_states
348 | else:
349 | # get encoder_hidden_states, ip_hidden_states
350 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
351 | encoder_hidden_states, ip_hidden_states = (
352 | encoder_hidden_states[:, :end_pos, :],
353 | encoder_hidden_states[:, end_pos:, :],
354 | )
355 | if attn.norm_cross:
356 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
357 |
358 | key = attn.to_k(encoder_hidden_states)
359 | value = attn.to_v(encoder_hidden_states)
360 |
361 | inner_dim = key.shape[-1]
362 | head_dim = inner_dim // attn.heads
363 |
364 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
365 |
366 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
367 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
368 |
369 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
370 | # TODO: add support for attn.scale when we move to Torch 2.1
371 | hidden_states = F.scaled_dot_product_attention(
372 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
373 | )
374 |
375 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
376 | hidden_states = hidden_states.to(query.dtype)
377 |
378 | # for ip-adapter
379 | ip_key = self.to_k_ip(ip_hidden_states)
380 | ip_value = self.to_v_ip(ip_hidden_states)
381 |
382 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
383 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
384 |
385 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
386 | # TODO: add support for attn.scale when we move to Torch 2.1
387 | ip_hidden_states = F.scaled_dot_product_attention(
388 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
389 | )
390 | with torch.no_grad():
391 | self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
392 | #print(self.attn_map.shape)
393 |
394 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
395 | ip_hidden_states = ip_hidden_states.to(query.dtype)
396 |
397 | hidden_states = hidden_states + self.scale * ip_hidden_states
398 |
399 | # linear proj
400 | hidden_states = attn.to_out[0](hidden_states)
401 | # dropout
402 | hidden_states = attn.to_out[1](hidden_states)
403 |
404 | if input_ndim == 4:
405 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
406 |
407 | if attn.residual_connection:
408 | hidden_states = hidden_states + residual
409 |
410 | hidden_states = hidden_states / attn.rescale_output_factor
411 |
412 | return hidden_states
413 |
414 |
415 | ## for controlnet
416 | class CNAttnProcessor:
417 | r"""
418 | Default processor for performing attention-related computations.
419 | """
420 |
421 | def __init__(self, num_tokens=4):
422 | self.num_tokens = num_tokens
423 |
424 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
425 | residual = hidden_states
426 |
427 | if attn.spatial_norm is not None:
428 | hidden_states = attn.spatial_norm(hidden_states, temb)
429 |
430 | input_ndim = hidden_states.ndim
431 |
432 | if input_ndim == 4:
433 | batch_size, channel, height, width = hidden_states.shape
434 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
435 |
436 | batch_size, sequence_length, _ = (
437 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
438 | )
439 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
440 |
441 | if attn.group_norm is not None:
442 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
443 |
444 | query = attn.to_q(hidden_states)
445 |
446 | if encoder_hidden_states is None:
447 | encoder_hidden_states = hidden_states
448 | else:
449 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
450 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
451 | if attn.norm_cross:
452 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
453 |
454 | key = attn.to_k(encoder_hidden_states)
455 | value = attn.to_v(encoder_hidden_states)
456 |
457 | query = attn.head_to_batch_dim(query)
458 | key = attn.head_to_batch_dim(key)
459 | value = attn.head_to_batch_dim(value)
460 |
461 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
462 | hidden_states = torch.bmm(attention_probs, value)
463 | hidden_states = attn.batch_to_head_dim(hidden_states)
464 |
465 | # linear proj
466 | hidden_states = attn.to_out[0](hidden_states)
467 | # dropout
468 | hidden_states = attn.to_out[1](hidden_states)
469 |
470 | if input_ndim == 4:
471 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
472 |
473 | if attn.residual_connection:
474 | hidden_states = hidden_states + residual
475 |
476 | hidden_states = hidden_states / attn.rescale_output_factor
477 |
478 | return hidden_states
479 |
480 |
481 | class CNAttnProcessor2_0:
482 | r"""
483 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
484 | """
485 |
486 | def __init__(self, num_tokens=4):
487 | if not hasattr(F, "scaled_dot_product_attention"):
488 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
489 | self.num_tokens = num_tokens
490 |
491 | def __call__(
492 | self,
493 | attn,
494 | hidden_states,
495 | encoder_hidden_states=None,
496 | attention_mask=None,
497 | temb=None,
498 | *args,
499 | **kwargs,
500 | ):
501 | residual = hidden_states
502 |
503 | if attn.spatial_norm is not None:
504 | hidden_states = attn.spatial_norm(hidden_states, temb)
505 |
506 | input_ndim = hidden_states.ndim
507 |
508 | if input_ndim == 4:
509 | batch_size, channel, height, width = hidden_states.shape
510 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
511 |
512 | batch_size, sequence_length, _ = (
513 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
514 | )
515 |
516 | if attention_mask is not None:
517 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
518 | # scaled_dot_product_attention expects attention_mask shape to be
519 | # (batch, heads, source_length, target_length)
520 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
521 |
522 | if attn.group_norm is not None:
523 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
524 |
525 | query = attn.to_q(hidden_states)
526 |
527 | if encoder_hidden_states is None:
528 | encoder_hidden_states = hidden_states
529 | else:
530 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
531 | encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
532 | if attn.norm_cross:
533 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
534 |
535 | key = attn.to_k(encoder_hidden_states)
536 | value = attn.to_v(encoder_hidden_states)
537 |
538 | inner_dim = key.shape[-1]
539 | head_dim = inner_dim // attn.heads
540 |
541 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
542 |
543 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
544 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
545 |
546 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
547 | # TODO: add support for attn.scale when we move to Torch 2.1
548 | hidden_states = F.scaled_dot_product_attention(
549 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
550 | )
551 |
552 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
553 | hidden_states = hidden_states.to(query.dtype)
554 |
555 | # linear proj
556 | hidden_states = attn.to_out[0](hidden_states)
557 | # dropout
558 | hidden_states = attn.to_out[1](hidden_states)
559 |
560 | if input_ndim == 4:
561 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
562 |
563 | if attn.residual_connection:
564 | hidden_states = hidden_states + residual
565 |
566 | hidden_states = hidden_states / attn.rescale_output_factor
567 |
568 | return hidden_states
569 |
--------------------------------------------------------------------------------
/ip_adapter/attention_processor_faceid.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.lora import LoRALinearLayer
7 |
8 |
9 | class LoRAAttnProcessor(nn.Module):
10 | r"""
11 | Default processor for performing attention-related computations.
12 | """
13 |
14 | def __init__(
15 | self,
16 | hidden_size=None,
17 | cross_attention_dim=None,
18 | rank=4,
19 | network_alpha=None,
20 | lora_scale=1.0,
21 | ):
22 | super().__init__()
23 |
24 | self.rank = rank
25 | self.lora_scale = lora_scale
26 |
27 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
28 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
29 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
30 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
31 |
32 | def __call__(
33 | self,
34 | attn,
35 | hidden_states,
36 | encoder_hidden_states=None,
37 | attention_mask=None,
38 | temb=None,
39 | *args,
40 | **kwargs,
41 | ):
42 | residual = hidden_states
43 |
44 | if attn.spatial_norm is not None:
45 | hidden_states = attn.spatial_norm(hidden_states, temb)
46 |
47 | input_ndim = hidden_states.ndim
48 |
49 | if input_ndim == 4:
50 | batch_size, channel, height, width = hidden_states.shape
51 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
52 |
53 | batch_size, sequence_length, _ = (
54 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
55 | )
56 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
57 |
58 | if attn.group_norm is not None:
59 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
60 |
61 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
62 |
63 | if encoder_hidden_states is None:
64 | encoder_hidden_states = hidden_states
65 | elif attn.norm_cross:
66 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
67 |
68 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
69 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
70 |
71 | query = attn.head_to_batch_dim(query)
72 | key = attn.head_to_batch_dim(key)
73 | value = attn.head_to_batch_dim(value)
74 |
75 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
76 | hidden_states = torch.bmm(attention_probs, value)
77 | hidden_states = attn.batch_to_head_dim(hidden_states)
78 |
79 | # linear proj
80 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
81 | # dropout
82 | hidden_states = attn.to_out[1](hidden_states)
83 |
84 | if input_ndim == 4:
85 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
86 |
87 | if attn.residual_connection:
88 | hidden_states = hidden_states + residual
89 |
90 | hidden_states = hidden_states / attn.rescale_output_factor
91 |
92 | return hidden_states
93 |
94 |
95 | class LoRAIPAttnProcessor(nn.Module):
96 | r"""
97 | Attention processor for IP-Adapter.
98 | Args:
99 | hidden_size (`int`):
100 | The hidden size of the attention layer.
101 | cross_attention_dim (`int`):
102 | The number of channels in the `encoder_hidden_states`.
103 | scale (`float`, defaults to 1.0):
104 | the weight scale of image prompt.
105 | num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
106 | The context length of the image features.
107 | """
108 |
109 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
110 | super().__init__()
111 |
112 | self.rank = rank
113 | self.lora_scale = lora_scale
114 |
115 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
116 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
117 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
118 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
119 |
120 | self.hidden_size = hidden_size
121 | self.cross_attention_dim = cross_attention_dim
122 | self.scale = scale
123 | self.num_tokens = num_tokens
124 |
125 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
126 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
127 |
128 | def __call__(
129 | self,
130 | attn,
131 | hidden_states,
132 | encoder_hidden_states=None,
133 | attention_mask=None,
134 | temb=None,
135 | *args,
136 | **kwargs,
137 | ):
138 | residual = hidden_states
139 |
140 | if attn.spatial_norm is not None:
141 | hidden_states = attn.spatial_norm(hidden_states, temb)
142 |
143 | input_ndim = hidden_states.ndim
144 |
145 | if input_ndim == 4:
146 | batch_size, channel, height, width = hidden_states.shape
147 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
148 |
149 | batch_size, sequence_length, _ = (
150 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
151 | )
152 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
153 |
154 | if attn.group_norm is not None:
155 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
156 |
157 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
158 |
159 | if encoder_hidden_states is None:
160 | encoder_hidden_states = hidden_states
161 | else:
162 | # get encoder_hidden_states, ip_hidden_states
163 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
164 | encoder_hidden_states, ip_hidden_states = (
165 | encoder_hidden_states[:, :end_pos, :],
166 | encoder_hidden_states[:, end_pos:, :],
167 | )
168 | if attn.norm_cross:
169 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
170 |
171 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
172 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
173 |
174 | query = attn.head_to_batch_dim(query)
175 | key = attn.head_to_batch_dim(key)
176 | value = attn.head_to_batch_dim(value)
177 |
178 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
179 | hidden_states = torch.bmm(attention_probs, value)
180 | hidden_states = attn.batch_to_head_dim(hidden_states)
181 |
182 | # for ip-adapter
183 | ip_key = self.to_k_ip(ip_hidden_states)
184 | ip_value = self.to_v_ip(ip_hidden_states)
185 |
186 | ip_key = attn.head_to_batch_dim(ip_key)
187 | ip_value = attn.head_to_batch_dim(ip_value)
188 |
189 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
190 | self.attn_map = ip_attention_probs
191 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
192 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
193 |
194 | hidden_states = hidden_states + self.scale * ip_hidden_states
195 |
196 | # linear proj
197 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
198 | # dropout
199 | hidden_states = attn.to_out[1](hidden_states)
200 |
201 | if input_ndim == 4:
202 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
203 |
204 | if attn.residual_connection:
205 | hidden_states = hidden_states + residual
206 |
207 | hidden_states = hidden_states / attn.rescale_output_factor
208 |
209 | return hidden_states
210 |
211 |
212 | class LoRAAttnProcessor2_0(nn.Module):
213 |
214 | r"""
215 | Default processor for performing attention-related computations.
216 | """
217 |
218 | def __init__(
219 | self,
220 | hidden_size=None,
221 | cross_attention_dim=None,
222 | rank=4,
223 | network_alpha=None,
224 | lora_scale=1.0,
225 | ):
226 | super().__init__()
227 |
228 | self.rank = rank
229 | self.lora_scale = lora_scale
230 |
231 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
232 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
233 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
234 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
235 |
236 | def __call__(
237 | self,
238 | attn,
239 | hidden_states,
240 | encoder_hidden_states=None,
241 | attention_mask=None,
242 | temb=None,
243 | *args,
244 | **kwargs,
245 | ):
246 | residual = hidden_states
247 |
248 | if attn.spatial_norm is not None:
249 | hidden_states = attn.spatial_norm(hidden_states, temb)
250 |
251 | input_ndim = hidden_states.ndim
252 |
253 | if input_ndim == 4:
254 | batch_size, channel, height, width = hidden_states.shape
255 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
256 |
257 | batch_size, sequence_length, _ = (
258 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
259 | )
260 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
261 |
262 | if attn.group_norm is not None:
263 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
264 |
265 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
266 |
267 | if encoder_hidden_states is None:
268 | encoder_hidden_states = hidden_states
269 | elif attn.norm_cross:
270 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
271 |
272 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
273 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
274 |
275 | inner_dim = key.shape[-1]
276 | head_dim = inner_dim // attn.heads
277 |
278 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
279 |
280 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
281 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
282 |
283 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
284 | # TODO: add support for attn.scale when we move to Torch 2.1
285 | hidden_states = F.scaled_dot_product_attention(
286 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
287 | )
288 |
289 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
290 | hidden_states = hidden_states.to(query.dtype)
291 |
292 | # linear proj
293 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
294 | # dropout
295 | hidden_states = attn.to_out[1](hidden_states)
296 |
297 | if input_ndim == 4:
298 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
299 |
300 | if attn.residual_connection:
301 | hidden_states = hidden_states + residual
302 |
303 | hidden_states = hidden_states / attn.rescale_output_factor
304 |
305 | return hidden_states
306 |
307 |
308 | class LoRAIPAttnProcessor2_0(nn.Module):
309 | r"""
310 |     Processor for implementing the LoRA attention mechanism together with IP-Adapter image-prompt attention.
311 |
312 | Args:
313 | hidden_size (`int`, *optional*):
314 | The hidden size of the attention layer.
315 | cross_attention_dim (`int`, *optional*):
316 | The number of channels in the `encoder_hidden_states`.
317 | rank (`int`, defaults to 4):
318 | The dimension of the LoRA update matrices.
319 | network_alpha (`int`, *optional*):
320 |             Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
321 | """
322 |
323 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
324 | super().__init__()
325 |
326 | self.rank = rank
327 | self.lora_scale = lora_scale
328 | self.num_tokens = num_tokens
329 |
330 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
331 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
332 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
333 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
334 |
335 |
336 | self.hidden_size = hidden_size
337 | self.cross_attention_dim = cross_attention_dim
338 | self.scale = scale
339 |
340 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
341 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
342 |
343 | def __call__(
344 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None, *args, **kwargs,
345 | ):
346 | residual = hidden_states
347 |
348 | if attn.spatial_norm is not None:
349 | hidden_states = attn.spatial_norm(hidden_states, temb)
350 |
351 | input_ndim = hidden_states.ndim
352 |
353 | if input_ndim == 4:
354 | batch_size, channel, height, width = hidden_states.shape
355 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
356 |
357 | batch_size, sequence_length, _ = (
358 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
359 | )
360 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
361 |
362 | if attn.group_norm is not None:
363 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
364 |
365 | query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
366 | #query = attn.head_to_batch_dim(query)
367 |
368 | if encoder_hidden_states is None:
369 | encoder_hidden_states = hidden_states
370 | else:
371 | # get encoder_hidden_states, ip_hidden_states
372 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
373 | encoder_hidden_states, ip_hidden_states = (
374 | encoder_hidden_states[:, :end_pos, :],
375 | encoder_hidden_states[:, end_pos:, :],
376 | )
377 | if attn.norm_cross:
378 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
379 |
380 | # for text
381 | key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
382 | value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
383 |
384 | inner_dim = key.shape[-1]
385 | head_dim = inner_dim // attn.heads
386 |
387 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
388 |
389 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
390 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
391 |
392 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
393 | # TODO: add support for attn.scale when we move to Torch 2.1
394 | hidden_states = F.scaled_dot_product_attention(
395 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
396 | )
397 |
398 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
399 | hidden_states = hidden_states.to(query.dtype)
400 |
401 | # for ip
402 | ip_key = self.to_k_ip(ip_hidden_states)
403 | ip_value = self.to_v_ip(ip_hidden_states)
404 |
405 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
406 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
407 |
408 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
409 | # TODO: add support for attn.scale when we move to Torch 2.1
410 | ip_hidden_states = F.scaled_dot_product_attention(
411 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
412 | )
413 |
414 |
415 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
416 | ip_hidden_states = ip_hidden_states.to(query.dtype)
417 |
418 | hidden_states = hidden_states + self.scale * ip_hidden_states
419 |
420 | # linear proj
421 | hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
422 | # dropout
423 | hidden_states = attn.to_out[1](hidden_states)
424 |
425 | if input_ndim == 4:
426 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
427 |
428 | if attn.residual_connection:
429 | hidden_states = hidden_states + residual
430 |
431 | hidden_states = hidden_states / attn.rescale_output_factor
432 |
433 | return hidden_states
434 |
--------------------------------------------------------------------------------
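The IP variants above expect the image-prompt tokens to be appended after the text tokens along the sequence dimension; inside `__call__` the last `num_tokens` entries are split back off as `ip_hidden_states`. A minimal sketch of that convention (the dimensions are illustrative assumptions):

    import torch

    num_tokens = 4                                   # must match the processor's num_tokens
    text_embeds = torch.randn(2, 77, 2048)           # text-encoder hidden states (illustrative dims)
    image_tokens = torch.randn(2, num_tokens, 2048)  # output of the image projection model

    # The pipeline passes this concatenation as encoder_hidden_states; the processor
    # recovers ip_hidden_states via encoder_hidden_states[:, -num_tokens:, :].
    encoder_hidden_states = torch.cat([text_embeds, image_tokens], dim=1)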
/ip_adapter/ip_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import torch
5 | from diffusers import StableDiffusionPipeline
6 | from diffusers.pipelines.controlnet import MultiControlNetModel
7 | from PIL import Image
8 | from safetensors import safe_open
9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10 |
11 | from .utils import is_torch2_available, get_generator
12 |
13 | if is_torch2_available():
14 | from .attention_processor import (
15 | AttnProcessor2_0 as AttnProcessor,
16 | )
17 | from .attention_processor import (
18 | CNAttnProcessor2_0 as CNAttnProcessor,
19 | )
20 | from .attention_processor import (
21 | IPAttnProcessor2_0 as IPAttnProcessor,
22 | )
23 | else:
24 | from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25 | from .resampler import Resampler
26 |
27 |
28 | class ImageProjModel(torch.nn.Module):
29 | """Projection Model"""
30 |
31 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
32 | super().__init__()
33 |
34 | self.generator = None
35 | self.cross_attention_dim = cross_attention_dim
36 | self.clip_extra_context_tokens = clip_extra_context_tokens
37 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
38 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
39 |
40 | def forward(self, image_embeds):
41 | embeds = image_embeds
42 | clip_extra_context_tokens = self.proj(embeds).reshape(
43 | -1, self.clip_extra_context_tokens, self.cross_attention_dim
44 | )
45 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
46 | return clip_extra_context_tokens
47 |
48 |
49 | class MLPProjModel(torch.nn.Module):
50 | """SD model with image prompt"""
51 |
52 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
53 | super().__init__()
54 |
55 | self.proj = torch.nn.Sequential(
56 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
57 | torch.nn.GELU(),
58 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
59 | torch.nn.LayerNorm(cross_attention_dim),
60 | )
61 |
62 | def forward(self, image_embeds):
63 | clip_extra_context_tokens = self.proj(image_embeds)
64 | return clip_extra_context_tokens
65 |
66 |
67 | class IPAdapter:
68 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
69 | self.device = device
70 | self.image_encoder_path = image_encoder_path
71 | self.ip_ckpt = ip_ckpt
72 | self.num_tokens = num_tokens
73 |
74 | self.pipe = sd_pipe.to(self.device)
75 | self.set_ip_adapter()
76 |
77 | # load image encoder
78 | # self.image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder").to(
79 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
80 | "h94/IP-Adapter", subfolder=image_encoder_path
81 | ).to(self.device, dtype=torch.float16)
82 | self.clip_image_processor = CLIPImageProcessor()
83 | # image proj model
84 | self.image_proj_model = self.init_proj()
85 |
86 | self.load_ip_adapter()
87 |
88 | def init_proj(self):
89 | image_proj_model = ImageProjModel(
90 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
91 | clip_embeddings_dim=self.image_encoder.config.projection_dim,
92 | clip_extra_context_tokens=self.num_tokens,
93 | ).to(self.device, dtype=torch.float16)
94 | return image_proj_model
95 |
96 | def set_ip_adapter(self):
97 | unet = self.pipe.unet
98 | attn_procs = {}
99 | for name in unet.attn_processors.keys():
100 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
101 | if name.startswith("mid_block"):
102 | hidden_size = unet.config.block_out_channels[-1]
103 | elif name.startswith("up_blocks"):
104 | block_id = int(name[len("up_blocks.")])
105 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
106 | elif name.startswith("down_blocks"):
107 | block_id = int(name[len("down_blocks.")])
108 | hidden_size = unet.config.block_out_channels[block_id]
109 | if cross_attention_dim is None:
110 | attn_procs[name] = AttnProcessor()
111 | else:
112 | attn_procs[name] = IPAttnProcessor(
113 | hidden_size=hidden_size,
114 | cross_attention_dim=cross_attention_dim,
115 | scale=1.0,
116 | num_tokens=self.num_tokens,
117 | ).to(self.device, dtype=torch.float16)
118 | unet.set_attn_processor(attn_procs)
119 | if hasattr(self.pipe, "controlnet"):
120 | if isinstance(self.pipe.controlnet, MultiControlNetModel):
121 | for controlnet in self.pipe.controlnet.nets:
122 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
123 | else:
124 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
125 |
126 | def load_ip_adapter(self):
127 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
128 | state_dict = {"image_proj": {}, "ip_adapter": {}}
129 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
130 | for key in f.keys():
131 | if key.startswith("image_proj."):
132 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
133 | elif key.startswith("ip_adapter."):
134 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
135 | else:
136 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
137 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
138 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
139 | ip_layers.load_state_dict(state_dict["ip_adapter"])
140 |
141 | def save_ip_adapter(self, save_path):
142 | state_dict = {
143 | "image_proj": self.image_proj_model.state_dict(),
144 | "ip_adapter": torch.nn.ModuleList(self.pipe.unet.attn_processors.values()).state_dict(),
145 | }
146 | torch.save(state_dict, save_path)
147 |
148 | @torch.inference_mode()
149 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
150 | if pil_image is not None:
151 | if isinstance(pil_image, Image.Image):
152 | pil_image = [pil_image]
153 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
154 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
155 | else:
156 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
157 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
158 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
159 | return image_prompt_embeds, uncond_image_prompt_embeds
160 |
161 | def set_scale(self, scale):
162 | for attn_processor in self.pipe.unet.attn_processors.values():
163 | if isinstance(attn_processor, IPAttnProcessor):
164 | attn_processor.scale = scale
165 |
166 | def generate(
167 | self,
168 | pil_image=None,
169 | clip_image_embeds=None,
170 | prompt=None,
171 | negative_prompt=None,
172 | scale=1.0,
173 | num_samples=4,
174 | seed=None,
175 | guidance_scale=7.5,
176 | num_inference_steps=30,
177 | **kwargs,
178 | ):
179 | self.set_scale(scale)
180 |
181 | if pil_image is not None:
182 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
183 | else:
184 | num_prompts = clip_image_embeds.size(0)
185 |
186 | if prompt is None:
187 | prompt = "best quality, high quality"
188 | if negative_prompt is None:
189 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
190 |
191 | if not isinstance(prompt, List):
192 | prompt = [prompt] * num_prompts
193 | if not isinstance(negative_prompt, List):
194 | negative_prompt = [negative_prompt] * num_prompts
195 |
196 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
197 | pil_image=pil_image, clip_image_embeds=clip_image_embeds
198 | )
199 | bs_embed, seq_len, _ = image_prompt_embeds.shape
200 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
201 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
202 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
203 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
204 |
205 | with torch.inference_mode():
206 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
207 | prompt,
208 | device=self.device,
209 | num_images_per_prompt=num_samples,
210 | do_classifier_free_guidance=True,
211 | negative_prompt=negative_prompt,
212 | )
213 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
214 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
215 |
216 | generator = get_generator(seed, self.device)
217 |
218 | images = self.pipe(
219 | prompt_embeds=prompt_embeds,
220 | negative_prompt_embeds=negative_prompt_embeds,
221 | guidance_scale=guidance_scale,
222 | num_inference_steps=num_inference_steps,
223 | generator=generator,
224 | **kwargs,
225 | ).images
226 |
227 | return images
228 |
229 |
230 | class IPAdapterXL(IPAdapter):
231 | """SDXL"""
232 |
233 | def generate(
234 | self,
235 | pil_image,
236 | prompt=None,
237 | negative_prompt=None,
238 | scale=1.0,
239 | num_samples=4,
240 | seed=None,
241 | num_inference_steps=30,
242 | image_prompt_embeds=None,
243 | **kwargs,
244 | ):
245 | self.set_scale(scale)
246 |
247 | if pil_image is not None:
248 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
249 | else:
250 | num_prompts = 1
251 |
252 | if prompt is None:
253 | prompt = "best quality, high quality"
254 | if negative_prompt is None:
255 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
256 |
257 | if not isinstance(prompt, List):
258 | prompt = [prompt] * num_prompts
259 | if not isinstance(negative_prompt, List):
260 | negative_prompt = [negative_prompt] * num_prompts
261 |
262 | if pil_image is None:
263 | assert image_prompt_embeds is not None
264 | clip_image = self.clip_image_processor(
265 | images=[Image.new("RGB", (128, 128))], return_tensors="pt"
266 | ).pixel_values
267 | clip_image = clip_image.to(self.device, dtype=torch.float16)
268 | uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image)).image_embeds
269 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
270 | else:
271 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
272 | bs_embed, seq_len, _ = image_prompt_embeds.shape
273 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
274 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
275 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
276 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
277 |
278 | with torch.inference_mode():
279 | (
280 | prompt_embeds,
281 | negative_prompt_embeds,
282 | pooled_prompt_embeds,
283 | negative_pooled_prompt_embeds,
284 | ) = self.pipe.encode_prompt(
285 | prompt,
286 | num_images_per_prompt=num_samples,
287 | do_classifier_free_guidance=True,
288 | negative_prompt=negative_prompt,
289 | )
290 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
291 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
292 |
293 | self.generator = get_generator(seed, self.device)
294 |
295 | images = self.pipe(
296 | prompt_embeds=prompt_embeds,
297 | negative_prompt_embeds=negative_prompt_embeds,
298 | pooled_prompt_embeds=pooled_prompt_embeds,
299 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
300 | num_inference_steps=num_inference_steps,
301 | generator=self.generator,
302 | **kwargs,
303 | ).images
304 |
305 | return images
306 |
307 |
308 | class IPAdapterPlus(IPAdapter):
309 | """IP-Adapter with fine-grained features"""
310 |
311 | def init_proj(self):
312 | image_proj_model = Resampler(
313 | dim=self.pipe.unet.config.cross_attention_dim,
314 | depth=4,
315 | dim_head=64,
316 | heads=12,
317 | num_queries=self.num_tokens,
318 | embedding_dim=self.image_encoder.config.hidden_size,
319 | output_dim=self.pipe.unet.config.cross_attention_dim,
320 | ff_mult=4,
321 | ).to(self.device, dtype=torch.float16)
322 | return image_proj_model
323 |
324 | @torch.inference_mode()
325 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
326 | if isinstance(pil_image, Image.Image):
327 | pil_image = [pil_image]
328 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
329 | clip_image = clip_image.to(self.device, dtype=torch.float16)
330 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
331 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
332 | uncond_clip_image_embeds = self.image_encoder(
333 | torch.zeros_like(clip_image), output_hidden_states=True
334 | ).hidden_states[-2]
335 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
336 | return image_prompt_embeds, uncond_image_prompt_embeds
337 |
338 |
339 | class IPAdapterFull(IPAdapterPlus):
340 | """IP-Adapter with full features"""
341 |
342 | def init_proj(self):
343 | image_proj_model = MLPProjModel(
344 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
345 | clip_embeddings_dim=self.image_encoder.config.hidden_size,
346 | ).to(self.device, dtype=torch.float16)
347 | return image_proj_model
348 |
349 |
350 | class IPAdapterPlusXL(IPAdapter):
351 | """SDXL"""
352 |
353 | def init_proj(self):
354 | image_proj_model = Resampler(
355 | dim=1280,
356 | depth=4,
357 | dim_head=64,
358 | heads=20,
359 | num_queries=self.num_tokens,
360 | embedding_dim=self.image_encoder.config.hidden_size,
361 | output_dim=self.pipe.unet.config.cross_attention_dim,
362 | ff_mult=4,
363 | ).to(self.device, dtype=torch.float16)
364 | return image_proj_model
365 |
366 | @torch.inference_mode()
367 | def get_image_embeds(self, pil_image, skip_uncond=False):
368 | if isinstance(pil_image, Image.Image):
369 | pil_image = [pil_image]
370 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
371 | else:
372 | clip_image = pil_image
373 | clip_image = clip_image.to(self.device, dtype=torch.float16)
374 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
375 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
376 | if skip_uncond:
377 | return image_prompt_embeds
378 | uncond_clip_image_embeds = self.image_encoder(
379 | torch.zeros_like(clip_image), output_hidden_states=True
380 | ).hidden_states[-2]
381 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
382 | return image_prompt_embeds, uncond_image_prompt_embeds
383 |
384 | def get_uncond_embeds(self):
385 | clip_image = self.clip_image_processor(images=[Image.new("RGB", (128, 128))], return_tensors="pt").pixel_values
386 | clip_image = clip_image.to(self.device, dtype=torch.float16)
387 | uncond_clip_image_embeds = self.image_encoder(
388 | torch.zeros_like(clip_image), output_hidden_states=True
389 | ).hidden_states[-2]
390 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
391 | return uncond_image_prompt_embeds
392 |
393 | def generate(
394 | self,
395 | pil_image=None,
396 | prompt=None,
397 | negative_prompt=None,
398 | scale=1.0,
399 | num_samples=4,
400 | seed=None,
401 | num_inference_steps=30,
402 | image_prompt_embeds=None,
403 | **kwargs,
404 | ):
405 | self.set_scale(scale)
406 |
407 | num_prompts = 1 # if isinstance(pil_image, Image.Image) else len(pil_image)
408 |
409 | if prompt is None:
410 | prompt = "best quality, high quality"
411 | if negative_prompt is None:
412 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
413 |
414 | if not isinstance(prompt, List):
415 | prompt = [prompt] * num_prompts
416 | if not isinstance(negative_prompt, List):
417 | negative_prompt = [negative_prompt] * num_prompts
418 |
419 | if image_prompt_embeds is None:
420 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
421 | else:
422 | uncond_image_prompt_embeds = self.get_uncond_embeds()
423 |
424 | bs_embed, seq_len, _ = image_prompt_embeds.shape
425 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
426 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
427 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
428 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
429 |
430 | with torch.inference_mode():
431 | (
432 | prompt_embeds,
433 | negative_prompt_embeds,
434 | pooled_prompt_embeds,
435 | negative_pooled_prompt_embeds,
436 | ) = self.pipe.encode_prompt(
437 | prompt,
438 | num_images_per_prompt=num_samples,
439 | do_classifier_free_guidance=True,
440 | negative_prompt=negative_prompt,
441 | )
442 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
443 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
444 |
445 | generator = get_generator(seed, self.device)
446 |
447 | images = self.pipe(
448 | prompt_embeds=prompt_embeds,
449 | negative_prompt_embeds=negative_prompt_embeds,
450 | pooled_prompt_embeds=pooled_prompt_embeds,
451 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
452 | num_inference_steps=num_inference_steps,
453 | generator=generator,
454 | **kwargs,
455 | ).images
456 |
457 | return images
458 |
--------------------------------------------------------------------------------
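As a hedged usage sketch for the classes above (the checkpoint paths are assumptions based on the download scripts in this repository; adjust them to your local setup), `IPAdapterPlusXL` can be driven from a standard diffusers SDXL pipeline:

    import torch
    from diffusers import StableDiffusionXLPipeline
    from PIL import Image

    from ip_adapter.ip_adapter import IPAdapterPlusXL

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    ip_model = IPAdapterPlusXL(
        sd_pipe=pipe,
        image_encoder_path="models/image_encoder",  # subfolder of h94/IP-Adapter
        ip_ckpt="weights/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",  # assumed local path
        device="cuda",
        num_tokens=16,
    )
    images = ip_model.generate(
        pil_image=Image.open("assets/character_sheet_default_ref.jpg").convert("RGB"),
        prompt="best quality, high quality",
        num_samples=1,
        num_inference_steps=30,
        seed=42,
    )
    images[0].save("ipadapter_plus_xl_sample.jpg")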
/ip_adapter/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 |
11 |
12 | # FFN
13 | def FeedForward(dim, mult=4):
14 | inner_dim = int(dim * mult)
15 | return nn.Sequential(
16 | nn.LayerNorm(dim),
17 | nn.Linear(dim, inner_dim, bias=False),
18 | nn.GELU(),
19 | nn.Linear(inner_dim, dim, bias=False),
20 | )
21 |
22 |
23 | def reshape_tensor(x, heads):
24 | bs, length, width = x.shape
25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26 | x = x.view(bs, length, heads, -1)
27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28 | x = x.transpose(1, 2)
29 |     # reshape keeps (bs, n_heads, length, dim_per_head); heads remain a separate dimension
30 | x = x.reshape(bs, heads, length, -1)
31 | return x
32 |
33 |
34 | class PerceiverAttention(nn.Module):
35 | def __init__(self, *, dim, dim_head=64, heads=8):
36 | super().__init__()
37 | self.scale = dim_head**-0.5
38 | self.dim_head = dim_head
39 | self.heads = heads
40 | inner_dim = dim_head * heads
41 |
42 | self.norm1 = nn.LayerNorm(dim)
43 | self.norm2 = nn.LayerNorm(dim)
44 |
45 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
48 |
49 | def forward(self, x, latents):
50 | """
51 | Args:
52 | x (torch.Tensor): image features
53 | shape (b, n1, D)
54 | latent (torch.Tensor): latent features
55 | shape (b, n2, D)
56 | """
57 | x = self.norm1(x)
58 | latents = self.norm2(latents)
59 |
60 | b, l, _ = latents.shape
61 |
62 | q = self.to_q(latents)
63 | kv_input = torch.cat((x, latents), dim=-2)
64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65 |
66 | q = reshape_tensor(q, self.heads)
67 | k = reshape_tensor(k, self.heads)
68 | v = reshape_tensor(v, self.heads)
69 |
70 | # attention
71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74 | out = weight @ v
75 |
76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77 |
78 | return self.to_out(out)
79 |
80 |
81 | class Resampler(nn.Module):
82 | def __init__(
83 | self,
84 | dim=1024,
85 | depth=8,
86 | dim_head=64,
87 | heads=16,
88 | num_queries=8,
89 | embedding_dim=768,
90 | output_dim=1024,
91 | ff_mult=4,
92 | max_seq_len: int = 257, # CLIP tokens + CLS token
93 | apply_pos_emb: bool = False,
94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95 | ):
96 | super().__init__()
97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98 |
99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100 |
101 | self.proj_in = nn.Linear(embedding_dim, dim)
102 |
103 | self.proj_out = nn.Linear(dim, output_dim)
104 | self.norm_out = nn.LayerNorm(output_dim)
105 |
106 | self.to_latents_from_mean_pooled_seq = (
107 | nn.Sequential(
108 | nn.LayerNorm(dim),
109 | nn.Linear(dim, dim * num_latents_mean_pooled),
110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111 | )
112 | if num_latents_mean_pooled > 0
113 | else None
114 | )
115 |
116 | self.layers = nn.ModuleList([])
117 | for _ in range(depth):
118 | self.layers.append(
119 | nn.ModuleList(
120 | [
121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122 | FeedForward(dim=dim, mult=ff_mult),
123 | ]
124 | )
125 | )
126 |
127 | def forward(self, x):
128 | if self.pos_emb is not None:
129 | n, device = x.shape[1], x.device
130 | pos_emb = self.pos_emb(torch.arange(n, device=device))
131 | x = x + pos_emb
132 |
133 | latents = self.latents.repeat(x.size(0), 1, 1)
134 |
135 | x = self.proj_in(x)
136 |
137 | if self.to_latents_from_mean_pooled_seq:
138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140 | latents = torch.cat((meanpooled_latents, latents), dim=-2)
141 |
142 | for attn, ff in self.layers:
143 | latents = attn(x, latents) + latents
144 | latents = ff(latents) + latents
145 |
146 | latents = self.proj_out(latents)
147 | return self.norm_out(latents)
148 |
149 |
150 | def masked_mean(t, *, dim, mask=None):
151 | if mask is None:
152 | return t.mean(dim=dim)
153 |
154 | denom = mask.sum(dim=dim, keepdim=True)
155 | mask = rearrange(mask, "b n -> b n 1")
156 | masked_t = t.masked_fill(~mask, 0.0)
157 |
158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
159 |
--------------------------------------------------------------------------------
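For intuition, the `Resampler` compresses a variable-length sequence of CLIP hidden states into a fixed number of query tokens in the UNet's cross-attention dimension. A small shape sketch (the 1280/2048 values mirror the SDXL `IPAdapterPlusXL` configuration above and are assumptions for other setups):

    import torch
    from ip_adapter.resampler import Resampler

    resampler = Resampler(
        dim=1280, depth=4, dim_head=64, heads=20,
        num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4,
    )
    clip_hidden = torch.randn(1, 257, 1280)  # penultimate CLIP ViT-H hidden states (CLS + 256 patches)
    tokens = resampler(clip_hidden)
    print(tokens.shape)  # torch.Size([1, 16, 2048])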
/ip_adapter/sd3_attention_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Union
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from diffusers.models.attention_processor import Attention
7 |
8 |
9 | class JointAttnProcessor2_0:
10 | """Attention processor used typically in processing the SD3-like self-attention projections."""
11 |
12 | def __init__(self):
13 | if not hasattr(F, "scaled_dot_product_attention"):
14 |             raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
15 |
16 | def __call__(
17 | self,
18 | attn: Attention,
19 | hidden_states: torch.FloatTensor,
20 | encoder_hidden_states: torch.FloatTensor = None,
21 | attention_mask: Optional[torch.FloatTensor] = None,
22 | *args,
23 | **kwargs,
24 | ) -> torch.FloatTensor:
25 | residual = hidden_states
26 |
27 | input_ndim = hidden_states.ndim
28 | if input_ndim == 4:
29 | batch_size, channel, height, width = hidden_states.shape
30 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
31 | context_input_ndim = encoder_hidden_states.ndim
32 | if context_input_ndim == 4:
33 | batch_size, channel, height, width = encoder_hidden_states.shape
34 | encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
35 |
36 | batch_size = encoder_hidden_states.shape[0]
37 |
38 | # `sample` projections.
39 | query = attn.to_q(hidden_states)
40 | key = attn.to_k(hidden_states)
41 | value = attn.to_v(hidden_states)
42 |
43 | # `context` projections.
44 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
45 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
46 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
47 |
48 | # attention
49 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
50 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
51 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
52 |
53 | inner_dim = key.shape[-1]
54 | head_dim = inner_dim // attn.heads
55 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
56 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
57 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
58 |
59 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
60 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
61 | hidden_states = hidden_states.to(query.dtype)
62 |
63 | # Split the attention outputs.
64 | hidden_states, encoder_hidden_states = (
65 | hidden_states[:, : residual.shape[1]],
66 | hidden_states[:, residual.shape[1] :],
67 | )
68 |
69 | # linear proj
70 | hidden_states = attn.to_out[0](hidden_states)
71 | # dropout
72 | hidden_states = attn.to_out[1](hidden_states)
73 | if not attn.context_pre_only:
74 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
75 |
76 | if input_ndim == 4:
77 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
78 | if context_input_ndim == 4:
79 | encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
80 |
81 | return hidden_states, encoder_hidden_states
82 |
83 |
84 | class IPJointAttnProcessor2_0(torch.nn.Module):
85 | """Attention processor used typically in processing the SD3-like self-attention projections."""
86 |
87 | def __init__(self, context_dim, hidden_dim, scale=1.0):
88 | if not hasattr(F, "scaled_dot_product_attention"):
89 |             raise ImportError("IPJointAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
90 | super().__init__()
91 | self.scale = scale
92 |
93 | self.add_k_proj_ip = nn.Linear(context_dim, hidden_dim)
94 | self.add_v_proj_ip = nn.Linear(context_dim, hidden_dim)
95 |
96 |
97 | def __call__(
98 | self,
99 | attn: Attention,
100 | hidden_states: torch.FloatTensor,
101 | encoder_hidden_states: torch.FloatTensor = None,
102 | attention_mask: Optional[torch.FloatTensor] = None,
103 | ip_hidden_states: torch.FloatTensor = None,
104 | *args,
105 | **kwargs,
106 | ) -> torch.FloatTensor:
107 | residual = hidden_states
108 |
109 | input_ndim = hidden_states.ndim
110 | if input_ndim == 4:
111 | batch_size, channel, height, width = hidden_states.shape
112 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
113 | context_input_ndim = encoder_hidden_states.ndim
114 | if context_input_ndim == 4:
115 | batch_size, channel, height, width = encoder_hidden_states.shape
116 | encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
117 |
118 | batch_size = encoder_hidden_states.shape[0]
119 |
120 | # `sample` projections.
121 | query = attn.to_q(hidden_states)
122 | key = attn.to_k(hidden_states)
123 | value = attn.to_v(hidden_states)
124 |
125 | sample_query = query # latent query
126 |
127 | # `context` projections.
128 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
129 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
130 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
131 |
132 | # attention
133 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
134 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
135 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
136 |
137 | inner_dim = key.shape[-1]
138 | head_dim = inner_dim // attn.heads
139 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
140 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
141 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
142 |
143 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
144 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
145 | hidden_states = hidden_states.to(query.dtype)
146 |
147 | # Split the attention outputs.
148 | hidden_states, encoder_hidden_states = (
149 | hidden_states[:, : residual.shape[1]],
150 | hidden_states[:, residual.shape[1] :],
151 | )
152 |
153 | # for ip-adapter
154 | ip_key = self.add_k_proj_ip(ip_hidden_states)
155 | ip_value = self.add_v_proj_ip(ip_hidden_states)
156 | ip_query = sample_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
157 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
158 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
159 |
160 | ip_hidden_states = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
161 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
162 | ip_hidden_states = ip_hidden_states.to(ip_query.dtype)
163 |
164 | hidden_states = hidden_states + self.scale * ip_hidden_states
165 |
166 | # linear proj
167 | hidden_states = attn.to_out[0](hidden_states)
168 | # dropout
169 | hidden_states = attn.to_out[1](hidden_states)
170 | if not attn.context_pre_only:
171 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
172 |
173 | if input_ndim == 4:
174 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
175 | if context_input_ndim == 4:
176 | encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
177 |
178 | return hidden_states, encoder_hidden_states
179 |
180 |
--------------------------------------------------------------------------------
/ip_adapter/test_resampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from resampler import Resampler
3 | from transformers import CLIPVisionModel
4 |
5 | BATCH_SIZE = 2
6 | OUTPUT_DIM = 1280
7 | NUM_QUERIES = 8
8 | NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
9 | APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
10 | IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
11 |
12 |
13 | def main():
14 | image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
15 | embedding_dim = image_encoder.config.hidden_size
16 |     print("image_encoder hidden size:", embedding_dim)
17 |
18 | image_proj_model = Resampler(
19 | dim=1024,
20 | depth=2,
21 | dim_head=64,
22 | heads=16,
23 | num_queries=NUM_QUERIES,
24 | embedding_dim=embedding_dim,
25 | output_dim=OUTPUT_DIM,
26 | ff_mult=2,
27 | max_seq_len=257,
28 | apply_pos_emb=APPLY_POS_EMB,
29 | num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
30 | )
31 |
32 | dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
33 | with torch.no_grad():
34 | image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
35 |     print("image_embeds shape:", image_embeds.shape)
36 |
37 | with torch.no_grad():
38 | ip_tokens = image_proj_model(image_embeds)
39 | print("ip_tokens shape:", ip_tokens.shape)
40 | assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/ip_adapter/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from PIL import Image
5 |
6 | attn_maps = {}
7 | def hook_fn(name):
8 | def forward_hook(module, input, output):
9 | if hasattr(module.processor, "attn_map"):
10 | attn_maps[name] = module.processor.attn_map
11 | del module.processor.attn_map
12 |
13 | return forward_hook
14 |
15 | def register_cross_attention_hook(unet):
16 | for name, module in unet.named_modules():
17 | if name.split('.')[-1].startswith('attn2'):
18 | module.register_forward_hook(hook_fn(name))
19 |
20 | return unet
21 |
22 | def upscale(attn_map, target_size):
23 | attn_map = torch.mean(attn_map, dim=0)
24 | attn_map = attn_map.permute(1,0)
25 | temp_size = None
26 |
27 | for i in range(0,5):
28 | scale = 2 ** i
29 |         if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31 | break
32 |
33 |     assert temp_size is not None, "temp_size cannot be None"
34 |
35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36 |
37 | attn_map = F.interpolate(
38 | attn_map.unsqueeze(0).to(dtype=torch.float32),
39 | size=target_size,
40 | mode='bilinear',
41 | align_corners=False
42 | )[0]
43 |
44 | attn_map = torch.softmax(attn_map, dim=0)
45 | return attn_map
46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47 |
48 | idx = 0 if instance_or_negative else 1
49 | net_attn_maps = []
50 |
51 | for name, attn_map in attn_maps.items():
52 | attn_map = attn_map.cpu() if detach else attn_map
53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54 | attn_map = upscale(attn_map, image_size)
55 | net_attn_maps.append(attn_map)
56 |
57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58 |
59 | return net_attn_maps
60 |
61 | def attnmaps2images(net_attn_maps):
62 |
63 | #total_attn_scores = 0
64 | images = []
65 |
66 | for attn_map in net_attn_maps:
67 | attn_map = attn_map.cpu().numpy()
68 | #total_attn_scores += attn_map.mean().item()
69 |
70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71 | normalized_attn_map = normalized_attn_map.astype(np.uint8)
72 | #print("norm: ", normalized_attn_map.shape)
73 | image = Image.fromarray(normalized_attn_map)
74 |
75 | #image = fix_save_attn_map(attn_map)
76 | images.append(image)
77 |
78 | #print(total_attn_scores)
79 | return images
80 | def is_torch2_available():
81 | return hasattr(F, "scaled_dot_product_attention")
82 |
83 | def get_generator(seed, device):
84 |
85 | if seed is not None:
86 | if isinstance(seed, list):
87 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88 | else:
89 | generator = torch.Generator(device).manual_seed(seed)
90 | else:
91 | generator = None
92 |
93 | return generator
--------------------------------------------------------------------------------
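A hedged sketch of how the attention-map helpers are intended to be used; it assumes attention processors that actually store `attn_map` (e.g. the non-SDPA `IPAttnProcessor`/`LoRAIPAttnProcessor` variants) and reuses the `pipe`, `ip_model`, and `ref_image` names from the sketch above:

    from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images

    pipe.unet = register_cross_attention_hook(pipe.unet)  # hook every attn2 module before sampling
    images = ip_model.generate(pil_image=ref_image, num_samples=1, seed=0)

    attn_map = get_net_attn_map(image_size=(1024, 1024))  # average map over all hooked layers
    for i, im in enumerate(attnmaps2images(attn_map)):
        im.save(f"attn_token_{i}.png")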
/ip_lora_inference/download_ip_adapter.sh:
--------------------------------------------------------------------------------
1 | huggingface-cli download h94/IP-Adapter --repo-type model --include "sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" --local-dir ./weights/ip_adapter/
2 |
--------------------------------------------------------------------------------
/ip_lora_inference/download_loras.sh:
--------------------------------------------------------------------------------
1 | echo "using HF_API_TOKEN: $HF_API_TOKEN"
2 |
3 | huggingface-cli download kfirgold99/Piece-it-Together --repo-type model --include "background_generation/pytorch_lora_weights.safetensors" --local-dir ./weights/
4 | huggingface-cli download kfirgold99/Piece-it-Together --repo-type model --include "character_sheet/pytorch_lora_weights.safetensors" --local-dir ./weights/
--------------------------------------------------------------------------------
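If `huggingface-cli` is unavailable, the same LoRA weights can be fetched programmatically; this is a hedged alternative, not part of the repository's scripts:

    from huggingface_hub import hf_hub_download

    for subfolder in ("background_generation", "character_sheet"):
        hf_hub_download(
            repo_id="kfirgold99/Piece-it-Together",
            filename=f"{subfolder}/pytorch_lora_weights.safetensors",
            local_dir="./weights/",
        )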
/ip_lora_inference/inference_ip_lora.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Optional, Union
4 |
5 | import matplotlib.pyplot as plt
6 | import pyrallis
7 | import torch
8 | from diffusers import UNet2DConditionModel
9 | from PIL import Image
10 | from torchvision import transforms
11 | from transformers import AutoModelForImageSegmentation
12 |
13 | from ip_lora_train.ip_adapter_for_lora import IPAdapterPlusXLLoRA
14 | from ip_lora_train.sdxl_ip_lora_pipeline import StableDiffusionXLIPLoRAPipeline
15 |
16 |
17 | @dataclass
18 | class ExperimentConfig:
19 | lora_type: str = "character_sheet"
20 | lora_path: Path = Path("weights/character_sheet/pytorch_lora_weights.safetensors")
21 | prompt: str = "a character sheet displaying a creature, from several angles with 1 large front view in the middle, clean white background. In the background we can see half-completed, partially colored, sketches of different parts of the object"
22 | output_dir: Path = Path("ip_lora_inference/character_sheet/")
23 | ref_images_paths: Union[list[Path], Path] = Path("assets/character_sheet_default_ref.jpg")
24 | ip_adapter_path: Path = Path("weights/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin")
25 | seed: Optional[int] = None
26 | num_inference_steps: int = 50
27 | remove_background: bool = True
28 |
29 | def __post_init__(self):
30 | assert self.lora_type in ["character_sheet", "background_generation"]
31 | assert self.lora_path.exists(), f"Lora path {self.lora_path} does not exist"
32 | assert self.ip_adapter_path.exists(), f"IP adapter path {self.ip_adapter_path} does not exist"
33 | self.output_dir.mkdir(parents=True, exist_ok=True)
34 | if isinstance(self.ref_images_paths, Path):
35 | self.ref_images_paths = [self.ref_images_paths]
36 | for ref_image_path in self.ref_images_paths:
37 | assert ref_image_path.exists(), f"Reference image path {ref_image_path} does not exist"
38 |
39 |
40 | def load_rmbg_model():
41 | model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
42 | torch.set_float32_matmul_precision(["high", "highest"][0])
43 | model.to("cuda")
44 | model.eval()
45 | return model
46 |
47 |
48 | def remove_background(model, image: Image.Image) -> Image.Image:
49 | assert image.size == (1024, 1024)
50 | # Data settings
51 | image_size = (1024, 1024)
52 | transform_image = transforms.Compose(
53 | [
54 | transforms.Resize(image_size),
55 | transforms.ToTensor(),
56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
57 | ]
58 | )
59 |
60 | input_images = transform_image(image).unsqueeze(0).to("cuda")
61 |
62 | # Prediction
63 | with torch.no_grad():
64 | preds = model(input_images)[-1].sigmoid().cpu()
65 | pred = preds[0].squeeze()
66 | mask = transforms.ToPILImage()(pred).resize(image_size)
67 |
68 | # Create white background
69 | white_bg = Image.new("RGB", image_size, (255, 255, 255))
70 | # Paste original image using mask
71 | white_bg.paste(image, mask=mask)
72 | return white_bg
73 |
74 |
75 | @torch.inference_mode()
76 | @pyrallis.wrap()
77 | def main(cfg: ExperimentConfig):
78 | pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
79 | print("loading unet")
80 | unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet")
81 | print("loading ip model")
82 | ip_model = IPAdapterPlusXLLoRA(
83 | unet=unet,
84 | image_encoder_path="models/image_encoder",
85 | ip_ckpt=cfg.ip_adapter_path,
86 | device="cuda",
87 | num_tokens=16,
88 | )
89 | print("loading pipeline")
90 | pipe = StableDiffusionXLIPLoRAPipeline.from_pretrained(
91 | pipe_id,
92 | unet=unet,
93 | variant=None,
94 | torch_dtype=torch.float32,
95 | )
96 | print("loading lora")
97 | pipe.load_lora_weights(
98 | pretrained_model_name_or_path_or_dict=cfg.lora_path,
99 | adapter_name="lora1",
100 | )
101 | pipe.set_adapters(["lora1"], adapter_weights=[1.0])
102 |
103 | pipe.to("cuda")
104 | print("running inference")
105 | if cfg.remove_background:
106 | rmbg_model = load_rmbg_model()
107 | else:
108 | rmbg_model = None
109 | for ref_image_path in cfg.ref_images_paths:
110 | if cfg.seed is not None:
111 | generator = torch.Generator("cuda").manual_seed(cfg.seed)
112 | else:
113 | generator = None
114 | ref_image = Image.open(ref_image_path).convert("RGB")
115 | if cfg.remove_background:
116 | rmbg_model.cuda()
117 | ref_image = remove_background(model=rmbg_model, image=ref_image)
118 | rmbg_model.cpu()
119 | image_name = ref_image_path.stem
120 | image = pipe(
121 | cfg.prompt,
122 | ip_adapter_image=ref_image,
123 | ip_model=ip_model,
124 | num_inference_steps=cfg.num_inference_steps,
125 | generator=generator,
126 | ).images[0]
127 | image.save(cfg.output_dir / f"{image_name}_pred.jpg")
128 | ref_image.save(cfg.output_dir / f"{image_name}_ref.jpg")
129 |
130 | # Create side-by-side plot
131 | fig, ax = plt.subplots(1, 2, figsize=(20, 10))
132 | ax[0].imshow(ref_image)
133 | ax[0].axis("off")
134 | ax[1].imshow(image)
135 | ax[1].axis("off")
136 | fig.suptitle(cfg.prompt, fontsize=24, wrap=True)
137 | plt.tight_layout()
138 | plt.savefig(cfg.output_dir / f"{image_name}_side_by_side.jpeg", bbox_inches="tight", dpi=300)
139 | plt.close()
140 |
141 |
142 | if __name__ == "__main__":
143 | main()
144 |
--------------------------------------------------------------------------------
/ip_lora_train/ip_adapter_for_lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 |
4 | import torch
5 | from PIL import Image
6 | from safetensors import safe_open
7 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
8 |
9 | from ip_adapter.ip_adapter import ImageProjModel
10 | from ip_adapter.resampler import Resampler
11 | from ip_adapter.utils import get_generator, is_torch2_available
12 |
13 | if is_torch2_available():
14 | from ip_adapter.attention_processor import (
15 | AttnProcessor2_0 as AttnProcessor,
16 | )
17 | from ip_adapter.attention_processor import (
18 | CNAttnProcessor2_0 as CNAttnProcessor,
19 | )
20 | from ip_adapter.attention_processor import (
21 | IPAttnProcessor2_0 as IPAttnProcessor,
22 | )
23 | else:
24 | from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
25 |
26 | WEIGHT_DTYPE = torch.float32
27 |
28 | class IPAdapterLoRA:
29 | def __init__(self, unet, image_encoder_path, ip_ckpt, device, num_tokens=4):
30 | self.device = device
31 | self.image_encoder_path = image_encoder_path
32 | self.ip_ckpt = ip_ckpt
33 | self.num_tokens = num_tokens
34 |
35 | self.unet = unet
36 | self.set_ip_adapter()
37 |
38 | # load image encoder
39 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
40 | "h94/IP-Adapter", subfolder="models/image_encoder"
41 | ).to(self.device, dtype=WEIGHT_DTYPE)
42 | self.clip_image_processor = CLIPImageProcessor()
43 | # image proj model
44 | self.image_proj_model = self.init_proj()
45 |
46 | self.load_ip_adapter()
47 |
48 | def init_proj(self):
49 | image_proj_model = ImageProjModel(
50 | cross_attention_dim=self.unet.config.cross_attention_dim,
51 | clip_embeddings_dim=self.image_encoder.config.projection_dim,
52 | clip_extra_context_tokens=self.num_tokens,
53 | ).to(self.device, dtype=WEIGHT_DTYPE)
54 | return image_proj_model
55 |
56 | def set_ip_adapter(self):
57 | unet = self.unet
58 | attn_procs = {}
59 | for name in unet.attn_processors.keys():
60 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
61 | if name.startswith("mid_block"):
62 | hidden_size = unet.config.block_out_channels[-1]
63 | elif name.startswith("up_blocks"):
64 | block_id = int(name[len("up_blocks.")])
65 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
66 | elif name.startswith("down_blocks"):
67 | block_id = int(name[len("down_blocks.")])
68 | hidden_size = unet.config.block_out_channels[block_id]
69 | if cross_attention_dim is None:
70 | attn_procs[name] = AttnProcessor()
71 | else:
72 | attn_procs[name] = IPAttnProcessor(
73 | hidden_size=hidden_size,
74 | cross_attention_dim=cross_attention_dim,
75 | scale=1.0,
76 | num_tokens=self.num_tokens,
77 | ).to(self.device, dtype=WEIGHT_DTYPE)
78 | unet.set_attn_processor(attn_procs)
79 |
80 | def load_ip_adapter(self):
81 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
82 | state_dict = {"image_proj": {}, "ip_adapter": {}}
83 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
84 | for key in f.keys():
85 | if key.startswith("image_proj."):
86 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
87 | elif key.startswith("ip_adapter."):
88 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
89 | else:
90 | state_dict = torch.load(self.ip_ckpt, map_location="cpu")
91 | self.image_proj_model.load_state_dict(state_dict["image_proj"])
92 | ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
93 | ip_layers.load_state_dict(state_dict["ip_adapter"])
94 |
95 | def save_ip_adapter(self, save_path):
96 | state_dict = {
97 | "image_proj": self.image_proj_model.state_dict(),
98 | "ip_adapter": torch.nn.ModuleList(self.unet.attn_processors.values()).state_dict(),
99 | }
100 | torch.save(state_dict, save_path)
101 |
102 | @torch.inference_mode()
103 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
104 | if pil_image is not None:
105 | if isinstance(pil_image, Image.Image):
106 | pil_image = [pil_image]
107 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
108 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=WEIGHT_DTYPE)).image_embeds
109 | else:
110 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=WEIGHT_DTYPE)
111 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
112 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
113 | return image_prompt_embeds, uncond_image_prompt_embeds
114 |
115 | def set_scale(self, scale):
116 | for attn_processor in self.unet.attn_processors.values():
117 | if isinstance(attn_processor, IPAttnProcessor):
118 | attn_processor.scale = scale
119 |
120 | def generate(
121 | self,
122 | pil_image=None,
123 | clip_image_embeds=None,
124 | prompt=None,
125 | negative_prompt=None,
126 | scale=1.0,
127 | num_samples=4,
128 | seed=None,
129 | guidance_scale=7.5,
130 | num_inference_steps=30,
131 | **kwargs,
132 | ):
133 | self.set_scale(scale)
134 |
135 | if pil_image is not None:
136 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
137 | else:
138 | num_prompts = clip_image_embeds.size(0)
139 |
140 | if prompt is None:
141 | prompt = "best quality, high quality"
142 | if negative_prompt is None:
143 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
144 |
145 | if not isinstance(prompt, List):
146 | prompt = [prompt] * num_prompts
147 | if not isinstance(negative_prompt, List):
148 | negative_prompt = [negative_prompt] * num_prompts
149 |
150 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
151 | pil_image=pil_image, clip_image_embeds=clip_image_embeds
152 | )
153 | bs_embed, seq_len, _ = image_prompt_embeds.shape
154 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
155 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
156 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
157 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
158 |
159 | with torch.inference_mode():
160 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
161 | prompt,
162 | device=self.device,
163 | num_images_per_prompt=num_samples,
164 | do_classifier_free_guidance=True,
165 | negative_prompt=negative_prompt,
166 | )
167 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
168 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
169 |
170 | generator = get_generator(seed, self.device)
171 |
172 | images = self.pipe(
173 | prompt_embeds=prompt_embeds,
174 | negative_prompt_embeds=negative_prompt_embeds,
175 | guidance_scale=guidance_scale,
176 | num_inference_steps=num_inference_steps,
177 | generator=generator,
178 | **kwargs,
179 | ).images
180 |
181 | return images
182 |
183 |
184 | class IPAdapterPlusXLLoRA(IPAdapterLoRA):
185 | """SDXL"""
186 |
187 | def init_proj(self):
188 | image_proj_model = Resampler(
189 | dim=1280,
190 | depth=4,
191 | dim_head=64,
192 | heads=20,
193 | num_queries=self.num_tokens,
194 | embedding_dim=self.image_encoder.config.hidden_size,
195 | output_dim=self.unet.config.cross_attention_dim,
196 | ff_mult=4,
197 | ).to(self.device, dtype=WEIGHT_DTYPE)
198 | return image_proj_model
199 |
200 | @torch.inference_mode()
201 | def get_image_embeds(self, pil_image, skip_uncond=False):
202 | if isinstance(pil_image, Image.Image):
203 | pil_image = [pil_image]
204 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
205 | else:
206 | clip_image = pil_image
207 | clip_image = clip_image.to(self.device, dtype=WEIGHT_DTYPE)
208 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
209 | image_prompt_embeds = self.image_proj_model(clip_image_embeds)
210 | if skip_uncond:
211 | return image_prompt_embeds
212 | uncond_clip_image_embeds = self.image_encoder(
213 | torch.zeros_like(clip_image), output_hidden_states=True
214 | ).hidden_states[-2]
215 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
216 | return image_prompt_embeds, uncond_image_prompt_embeds
217 |
218 | def get_uncond_embeds(self):
219 | clip_image = self.clip_image_processor(images=[Image.new("RGB", (128, 128))], return_tensors="pt").pixel_values
220 | clip_image = clip_image.to(self.device, dtype=WEIGHT_DTYPE)
221 | uncond_clip_image_embeds = self.image_encoder(
222 | torch.zeros_like(clip_image), output_hidden_states=True
223 | ).hidden_states[-2]
224 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
225 | return uncond_image_prompt_embeds
226 |
227 | def generate(
228 | self,
229 | pil_image=None,
230 | prompt=None,
231 | negative_prompt=None,
232 | scale=1.0,
233 | num_samples=4,
234 | seed=None,
235 | num_inference_steps=30,
236 | image_prompt_embeds=None,
237 | **kwargs,
238 | ):
239 | self.set_scale(scale)
240 |
241 | num_prompts = 1 # if isinstance(pil_image, Image.Image) else len(pil_image)
242 |
243 | if prompt is None:
244 | prompt = "best quality, high quality"
245 | if negative_prompt is None:
246 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
247 |
248 | if not isinstance(prompt, List):
249 | prompt = [prompt] * num_prompts
250 | if not isinstance(negative_prompt, List):
251 | negative_prompt = [negative_prompt] * num_prompts
252 |
253 | if image_prompt_embeds is None:
254 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
255 | else:
256 | uncond_image_prompt_embeds = self.get_uncond_embeds()
257 |
258 | bs_embed, seq_len, _ = image_prompt_embeds.shape
259 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
260 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
261 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
262 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
263 |
264 | with torch.inference_mode():
265 | (
266 | prompt_embeds,
267 | negative_prompt_embeds,
268 | pooled_prompt_embeds,
269 | negative_pooled_prompt_embeds,
270 | ) = self.pipe.encode_prompt(
271 | prompt,
272 | num_images_per_prompt=num_samples,
273 | do_classifier_free_guidance=True,
274 | negative_prompt=negative_prompt,
275 | )
276 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
277 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
278 |
279 | generator = get_generator(seed, self.device)
280 |
281 | images = self.pipe(
282 | prompt_embeds=prompt_embeds,
283 | negative_prompt_embeds=negative_prompt_embeds,
284 | pooled_prompt_embeds=pooled_prompt_embeds,
285 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
286 | num_inference_steps=num_inference_steps,
287 | generator=generator,
288 | **kwargs,
289 | ).images
290 |
291 | return images
292 |
--------------------------------------------------------------------------------
/ip_lora_train/ip_lora_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision.transforms import Compose, Normalize, ToTensor
7 |
8 | try:
9 | from torchvision.transforms import InterpolationMode
10 |
11 | BICUBIC = InterpolationMode.BICUBIC
12 | except ImportError:
13 | BICUBIC = Image.BICUBIC
14 |
15 | CHARACTER_SHEET_PROMPT = "a character sheet displaying a creature, from several angles with 1 large front view in the middle, clean white background. In the background we can see half-completed, partially colored, sketches of different parts of the object"
16 |
17 |
18 | def _transform():
19 | return Compose(
20 | [
21 | ToTensor(),
22 | Normalize(
23 | [0.5],
24 | [0.5],
25 | ),
26 | ]
27 | )
28 |
29 |
30 | class IPLoraDataset(Dataset):
31 | def __init__(
32 | self,
33 | tokenizer1,
34 | tokenizer2,
35 | image_processor,
36 | target_resolution: int = 1024,
37 | base_dir: Path = Path("dataset/"),
38 | prompt_mode: str = "character_sheet",
39 | ):
40 | super().__init__()
41 | self.base_dir = base_dir
42 | self.samples = self.init_samples(base_dir)
43 | self.tokenizer1 = tokenizer1
44 | self.tokenizer2 = tokenizer2
45 | self.image_processor = image_processor
46 | self.target_resolution = target_resolution
47 | self.prompt_mode = prompt_mode
48 |
49 | def init_samples(self, base_dir):
50 | ref_dir = base_dir / "refs"
51 | targets_dir = base_dir / "targets"
52 | prompt_dir = base_dir / "targets"
53 | ref_files = list(ref_dir.glob("*.png")) + list(ref_dir.glob("*.jpg"))
54 | targets_files = list(targets_dir.glob("*.png")) + list(targets_dir.glob("*.jpg"))
55 | prompt_files = list(prompt_dir.glob("*.txt"))
56 | ref_prefixes = [f.stem.split("_ref")[0] for f in ref_files]
57 | targets_prefixes = [f.stem for f in targets_files]
58 | prompt_prefixes = [f.stem for f in prompt_files]
59 | valid_prefixes = list(set(ref_prefixes) & set(targets_prefixes) & set(prompt_prefixes))
60 | ref_png_paths = [ref_dir / f"{prefix}_ref.png" for prefix in valid_prefixes]
61 | ref_jpg_paths = [ref_dir / f"{prefix}_ref.jpg" for prefix in valid_prefixes]
62 | targets_png_paths = [targets_dir / f"{prefix}.png" for prefix in valid_prefixes]
63 | targets_jpg_paths = [targets_dir / f"{prefix}.jpg" for prefix in valid_prefixes]
64 | prompt_paths = [prompt_dir / f"{prefix}.txt" for prefix in valid_prefixes]
65 | samples = [
66 | {
67 | "ref": ref_png_paths[i] if ref_png_paths[i].exists() else ref_jpg_paths[i],
68 | "sheet": targets_png_paths[i] if targets_png_paths[i].exists() else targets_jpg_paths[i],
69 | "prompt": prompt_paths[i],
70 | }
71 | for i in range(len(valid_prefixes))
72 | ]
73 |         print(f"ip_lora_dataset.py:: found {len(samples)} samples")
74 | return samples
75 |
76 | def __len__(self):
77 | return len(self.samples)
78 |
79 | def get_prompt(self, prompt_text):
80 | if self.prompt_mode == "character_sheet":
81 | return CHARACTER_SHEET_PROMPT
82 | elif self.prompt_mode == "creature_in_scene":
83 | return prompt_text.split(",")[0] + " an imaginary fantasy creature in" + prompt_text.split("in a")[1]
84 | else:
85 | raise ValueError(f"Prompt mode {self.prompt_mode} is not supported.")
86 |
87 | def __getitem__(self, i: int):
88 | sample = self.samples[i]
89 | ref_path = sample["ref"]
90 | sheet_path = sample["sheet"]
91 | prompt_path = sample["prompt"]
92 | ref_image = Image.open(ref_path)
93 | sheet_image = Image.open(sheet_path)
94 |         prompt_text = prompt_path.read_text()
95 | sample_prompt = self.get_prompt(prompt_text)
96 | out_dict = {}
97 | input_ids1 = self.tokenizer1(
98 | [sample_prompt],
99 | max_length=self.tokenizer1.model_max_length,
100 | padding="max_length",
101 | truncation=True,
102 | return_tensors="pt",
103 | ).input_ids
104 | input_ids2 = self.tokenizer2(
105 | [sample_prompt],
106 | max_length=self.tokenizer2.model_max_length,
107 | padding="max_length",
108 | truncation=True,
109 | return_tensors="pt",
110 | ).input_ids
111 | out_dict["input_ids1"] = input_ids1
112 | out_dict["input_ids2"] = input_ids2
113 |
114 | out_dict["text"] = prompt_text
115 |
116 | processed_image_prompt = self.image_processor(images=[ref_image], return_tensors="pt").pixel_values
117 | out_dict["image_prompt"] = processed_image_prompt
118 | target_image_torch = _transform()(sheet_image)
119 | out_dict["target_image"] = target_image_torch
120 | out_dict["original_sizes"] = (self.target_resolution, self.target_resolution)
121 | out_dict["crop_top_lefts"] = (0, 0)
122 | return out_dict
123 |
124 |
125 | def ip_lora_collate_fn(batch):
126 | return_batch = {}
127 | return_batch["input_ids_one"] = torch.stack([item["input_ids1"] for item in batch])
128 | return_batch["input_ids_two"] = torch.stack([item["input_ids2"] for item in batch])
129 | return_batch["text"] = [item["text"] for item in batch]
130 | image_prompt = torch.stack([item["image_prompt"] for item in batch])
131 | image_prompt = image_prompt.to(memory_format=torch.contiguous_format).float()
132 | return_batch["image_prompt"] = image_prompt
133 | target_image = torch.stack([item["target_image"] for item in batch])
134 | target_image = target_image.to(memory_format=torch.contiguous_format).float()
135 |
136 | return_batch["target_image"] = target_image
137 | original_sizes = [item["original_sizes"] for item in batch]
138 | crop_top_lefts = [item["crop_top_lefts"] for item in batch]
139 | return_batch["original_sizes"] = original_sizes
140 | return_batch["crop_top_lefts"] = crop_top_lefts
141 | return return_batch
142 |
--------------------------------------------------------------------------------
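For reference, a minimal sketch of how `IPLoraDataset` and `ip_lora_collate_fn` above might be wired into a `DataLoader`. The SDXL tokenizer subfolders, the default `CLIPImageProcessor`, and the `dataset/` layout (a `refs/` and a `targets/` directory with matching prefixes) are assumptions for illustration, not guarantees made by the repository.

```python
# Hedged usage sketch: assumes SDXL tokenizers, a default CLIP image processor,
# and a dataset/ directory laid out as refs/ and targets/ with matching prefixes.
from pathlib import Path

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, CLIPImageProcessor

from ip_lora_train.ip_lora_dataset import IPLoraDataset, ip_lora_collate_fn

base = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer1 = AutoTokenizer.from_pretrained(base, subfolder="tokenizer")
tokenizer2 = AutoTokenizer.from_pretrained(base, subfolder="tokenizer_2")

dataset = IPLoraDataset(
    tokenizer1=tokenizer1,
    tokenizer2=tokenizer2,
    image_processor=CLIPImageProcessor(),
    base_dir=Path("dataset/"),
    prompt_mode="character_sheet",
)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=ip_lora_collate_fn)

batch = next(iter(loader))
print(batch["image_prompt"].shape, batch["target_image"].shape)
```

Note that the collate function stacks the `target_image` tensors directly, so the target sheets are expected to share a single resolution.
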
/ip_lora_train/run_example.sh:
--------------------------------------------------------------------------------
1 | python ./ip_lora_train/train_ip_lora.py \
2 | --rank 64 \
3 | --resolution 1024 \
4 | --validation_epochs 1 \
5 | --num_train_epochs 100 \
6 | --checkpointing_steps 50 \
7 | --train_batch_size 2 \
8 | --learning_rate 1e-4 \
9 | --dataloader_num_workers 1 \
10 | --gradient_accumulation_steps 8 \
11 | --dataset_base_dir \
12 | --prompt_mode character_sheet \
13 | --output_dir ./output/train_ip_lora/character_sheet
14 |
--------------------------------------------------------------------------------
/ip_plus_space_exploration/download_directions.sh:
--------------------------------------------------------------------------------
1 | huggingface-cli download kfirgold99/Piece-it-Together --repo-type model --include "ip_space_edit_directions/*" --local-dir ./weights/
--------------------------------------------------------------------------------
/ip_plus_space_exploration/download_ip_adapter.sh:
--------------------------------------------------------------------------------
1 | huggingface-cli download h94/IP-Adapter --repo-type model --include "sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" --local-dir ./weights/ip_adapter/
2 | huggingface-cli download h94/IP-Adapter --repo-type model --include "sdxl_models/ip-adapter_sdxl.bin" --local-dir ./weights/ip_adapter/
3 |
--------------------------------------------------------------------------------
/ip_plus_space_exploration/edit_by_direction.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | import matplotlib.pyplot as plt
6 | import pyrallis
7 | import torch
8 | from PIL import Image
9 |
10 | from ip_plus_space_exploration.ip_model_utils import get_image_ip_base_tokens, get_image_ip_plus_tokens, load_ip
11 |
12 |
13 | def get_images_with_ip_direction(ip_model, img_tokens, ip_direction, seed):
14 | scales_images = {k: [] for k in [-1, -0.5, 0, 0.5, 1]}
15 |
16 | for step_size in scales_images.keys():
17 | new_ip_tokens = img_tokens + step_size * ip_direction
18 | image = ip_model.generate(
19 | pil_image=None,
20 | scale=1.0,
21 | image_prompt_embeds=new_ip_tokens.cuda(),
22 | num_samples=1,
23 | num_inference_steps=50,
24 | seed=seed,
25 | )
26 | scales_images[step_size] = image
27 | return scales_images
28 |
29 |
30 | def get_images_with_clip_direction(ip_model, clip_features, clip_direction, seed):
31 | scales_images = {k: [] for k in [-1, -0.5, 0, 0.5, 1]}
32 |
33 | for step_size in scales_images.keys():
34 | new_clip_features = clip_features + step_size * clip_direction
35 | new_ip_tokens = ip_model.image_proj_model(new_clip_features.cuda())
36 | image = ip_model.generate(
37 | pil_image=None,
38 | scale=1.0,
39 | image_prompt_embeds=new_ip_tokens.cuda(),
40 | num_samples=1,
41 | num_inference_steps=50,
42 | seed=seed,
43 | )
44 | scales_images[step_size] = image
45 | return scales_images
46 |
47 |
48 | @dataclass
49 | class EditByDirectionConfig:
50 | ip_model_type: str
51 | image_path: Path
52 | direction_path: Path
53 | direction_type: str
54 | output_dir: Path
55 | seed: Optional[int] = None
56 |
57 | def __post_init__(self):
58 | self.output_dir.mkdir(parents=True, exist_ok=True)
59 | assert self.direction_type.lower() in ["ip", "clip"]
60 |
61 |
62 | @torch.inference_mode()
63 | @pyrallis.wrap()
64 | def main(cfg: EditByDirectionConfig):
65 | ip_model = load_ip(cfg.ip_model_type)
66 | edit_direction = torch.load(cfg.direction_path)
67 | if cfg.image_path.is_dir():
68 | image_paths = list(cfg.image_path.glob("*.jpg")) + list(cfg.image_path.glob("*.png"))
69 | else:
70 | image_paths = [cfg.image_path]
71 | for img in image_paths:
72 | if cfg.ip_model_type == "plus":
73 | img_tokens, img_clip_embeds = get_image_ip_plus_tokens(ip_model, Image.open(img))
74 | elif cfg.ip_model_type == "base":
75 | img_tokens, img_clip_embeds = get_image_ip_base_tokens(ip_model, Image.open(img))
76 | if cfg.direction_type.upper() == "IP":
77 | scales_images = get_images_with_ip_direction(
78 | ip_model=ip_model, img_tokens=img_tokens, ip_direction=edit_direction, seed=cfg.seed
79 | )
80 | elif cfg.direction_type.upper() == "CLIP":
81 | scales_images = get_images_with_clip_direction(
82 | ip_model=ip_model, clip_features=img_clip_embeds, clip_direction=edit_direction, seed=cfg.seed
83 | )
84 | # Plot all images in a single row
85 | fig, axes = plt.subplots(1, len(scales_images) + 1, figsize=(20, 5)) # Increased height from 4 to 5
86 |
87 | # Sort the scales for consistent left-to-right ordering
88 | sorted_scales = sorted(scales_images.keys())
89 |
90 | axes[0].imshow(Image.open(img))
91 | axes[0].set_title(f"original", pad=15) # Added padding to the title
92 | axes[0].axis("off")
93 |
94 | for i, scale in enumerate(sorted_scales):
95 | axes[i + 1].imshow(scales_images[scale][0])
96 | axes[i + 1].set_title(f"dist from boundary={scale}", pad=15) # Added padding to the title
97 | axes[i + 1].axis("off")
98 |
99 | plt.tight_layout(rect=[0, 0, 1, 0.95]) # Added rect parameter to leave more room at the top
100 |
101 | plt.savefig(cfg.output_dir / f"{img.stem}_{cfg.direction_type}.png")
102 | for scale in scales_images.keys():
103 | scales_images[scale][0].save(str(cfg.output_dir / f"{img.stem}_{cfg.direction_type}_{scale}.jpg"))
104 | plt.close()
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/ip_plus_space_exploration/find_direction.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 |
4 | import pyrallis
5 | import torch
6 | from PIL import Image
7 | from tqdm import tqdm
8 |
9 | from ip_plus_space_exploration.ip_model_utils import get_image_ip_base_tokens, get_image_ip_plus_tokens, load_ip
10 |
11 |
12 | def get_class_tokens(ip_model, class_dir, ip_model_type: str = "plus"):
13 | image_prompt_embeds_list = []
14 | clip_image_embeds_list = []
15 | images = list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png"))
16 | for _, image in tqdm(enumerate(images), total=len(images)):
17 | pil_image = Image.open(image)
18 | if ip_model_type == "plus":
19 | image_prompt_embeds, clip_image_embeds = get_image_ip_plus_tokens(ip_model, pil_image)
20 | elif ip_model_type == "base":
21 | image_prompt_embeds, clip_image_embeds = get_image_ip_base_tokens(ip_model, pil_image)
22 | image_prompt_embeds_list.append(image_prompt_embeds.detach().cpu())
23 | clip_image_embeds_list.append(clip_image_embeds.detach().cpu())
24 | return torch.cat(image_prompt_embeds_list), torch.cat(clip_image_embeds_list)
25 |
26 |
27 | @dataclass
28 | class FindDirectionConfig:
29 | class1_dir: Path
30 | class2_dir: Path
31 | output_dir: Path
32 | ip_model_type: str
33 |
34 | def __post_init__(self):
35 | self.output_dir.mkdir(parents=True, exist_ok=True)
36 | assert self.ip_model_type in ["plus", "base"]
37 |
38 |
39 | @torch.inference_mode()
40 | @pyrallis.wrap()
41 | def main(cfg: FindDirectionConfig):
42 | ip_model = load_ip(cfg.ip_model_type)
43 | class_1_image_prompt_embeds, class_1_clip_image_embeds = get_class_tokens(
44 | ip_model=ip_model, class_dir=cfg.class1_dir, ip_model_type=cfg.ip_model_type
45 | )
46 | class_2_image_prompt_embeds, class_2_clip_image_embeds = get_class_tokens(
47 | ip_model=ip_model, class_dir=cfg.class2_dir, ip_model_type=cfg.ip_model_type
48 | )
49 | clip_direction = class_2_clip_image_embeds.mean(dim=0) - class_1_clip_image_embeds.mean(dim=0)
50 | ip_direction = class_2_image_prompt_embeds.mean(dim=0) - class_1_image_prompt_embeds.mean(dim=0)
51 | torch.save(clip_direction, cfg.output_dir / "clip_direction.pt")
52 | torch.save(ip_direction, cfg.output_dir / "ip_direction.pt")
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/ip_plus_space_exploration/ip_model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import StableDiffusionXLPipeline
3 |
4 | from ip_adapter.ip_adapter import IPAdapterPlusXL, IPAdapterXL
5 |
6 |
7 | def load_ip(type: str = "plus"):
8 | sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
9 | "stabilityai/stable-diffusion-xl-base-1.0",
10 | ).to(dtype=torch.float16)
11 | if type == "plus":
12 | ip_model = IPAdapterPlusXL(
13 | sd_pipe=sdxl_pipeline,
14 | image_encoder_path="models/image_encoder",
15 | ip_ckpt="weights/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
16 | device="cuda",
17 | num_tokens=16,
18 | )
19 | elif type == "base":
20 | ip_model = IPAdapterXL(
21 | sd_pipe=sdxl_pipeline,
22 | image_encoder_path="sdxl_models/image_encoder",
23 | ip_ckpt="weights/ip_adapter/sdxl_models/ip-adapter_sdxl.bin",
24 | device="cuda",
25 | )
26 |
27 | return ip_model
28 |
29 |
30 | def get_image_ip_plus_tokens(ip_model, pil_image):
31 | image_processed = ip_model.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
32 | clip_image = image_processed.to("cuda", dtype=torch.float32)
33 | clip_image_embeds = ip_model.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
34 | image_prompt_embeds = ip_model.image_proj_model(clip_image_embeds)
35 | return image_prompt_embeds.cpu().detach(), clip_image_embeds.cpu().detach()
36 |
37 |
38 | def get_image_ip_base_tokens(ip_model, pil_image):
39 | image_processed = ip_model.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
40 | clip_image = image_processed.to("cuda", dtype=torch.float32)
41 | clip_image_embeds = ip_model.image_encoder(clip_image).image_embeds
42 | image_prompt_embeds = ip_model.image_proj_model(clip_image_embeds)
43 | return image_prompt_embeds.cpu().detach(), clip_image_embeds.cpu().detach()
44 |
--------------------------------------------------------------------------------
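Taken together, `find_direction.py` saves a mean-difference direction in either the CLIP or the IP-Adapter+ token space, and `edit_by_direction.py` applies it at several step sizes. Below is a minimal sketch of applying a saved IP-space direction to a single image; the image path, the direction path, and the step size are placeholders, and the adapter weights are assumed to have been fetched with the download scripts above.

```python
# Hedged sketch: "input.jpg" and the direction path are placeholders; IP-Adapter
# weights are assumed to be in place (see download_ip_adapter.sh / download_directions.sh).
import torch
from PIL import Image

from ip_plus_space_exploration.ip_model_utils import get_image_ip_plus_tokens, load_ip

ip_model = load_ip("plus")
img_tokens, _clip_embeds = get_image_ip_plus_tokens(ip_model, Image.open("input.jpg"))

ip_direction = torch.load("weights/ip_space_edit_directions/ip_direction.pt")
edited_tokens = img_tokens + 0.5 * ip_direction  # step along the learned direction

images = ip_model.generate(
    pil_image=None,
    scale=1.0,
    image_prompt_embeds=edited_tokens.cuda(),
    num_samples=1,
    num_inference_steps=50,
    seed=0,
)
images[0].save("edited.jpg")
```
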
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/model/__init__.py
--------------------------------------------------------------------------------
/model/dit.py:
--------------------------------------------------------------------------------
1 | # From the great https://github.com/cloneofsimo/minRF/blob/main/dit.py
2 | # Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
3 | # this is modeling code for DiT-LLaMA model
4 |
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from diffusers import ModelMixin, ConfigMixin
11 | from diffusers.configuration_utils import register_to_config
12 |
13 |
14 | def modulate(x, shift, scale):
15 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
16 |
17 |
18 | class TimestepEmbedder(nn.Module):
19 | def __init__(self, hidden_size, frequency_embedding_size=256):
20 | super().__init__()
21 | self.mlp = nn.Sequential(
22 | nn.Linear(frequency_embedding_size, hidden_size),
23 | nn.SiLU(),
24 | nn.Linear(hidden_size, hidden_size),
25 | )
26 | self.frequency_embedding_size = frequency_embedding_size
27 |
28 | @staticmethod
29 | def timestep_embedding(t, dim, max_period=10000):
30 | half = dim // 2
31 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half) / half).to(t.device)
32 | args = t[:, None] * freqs[None]
33 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
34 | if dim % 2:
35 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
36 | return embedding
37 |
38 | def forward(self, t):
39 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=next(self.parameters()).dtype)
40 | t_emb = self.mlp(t_freq)
41 | return t_emb
42 |
43 |
44 | class LabelEmbedder(nn.Module):
45 | def __init__(self, num_classes, hidden_size, dropout_prob):
46 | super().__init__()
47 | use_cfg_embedding = int(dropout_prob > 0)
48 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
49 | self.num_classes = num_classes
50 | self.dropout_prob = dropout_prob
51 |
52 | def token_drop(self, labels, force_drop_ids=None):
53 | if force_drop_ids is None:
54 | drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
55 |             # Move the drop mask to the labels' device (avoids a hard .cuda() dependency)
56 |             drop_ids = drop_ids.to(labels.device)
57 | else:
58 | drop_ids = force_drop_ids == 1
59 | labels = torch.where(drop_ids, self.num_classes, labels)
60 | return labels
61 |
62 | def forward(self, labels, train, force_drop_ids=None):
63 | use_dropout = self.dropout_prob > 0
64 | if (train and use_dropout) or (force_drop_ids is not None):
65 | labels = self.token_drop(labels, force_drop_ids)
66 | embeddings = self.embedding_table(labels)
67 | return embeddings
68 |
69 |
70 | class Attention(nn.Module):
71 | def __init__(self, dim, n_heads):
72 | super().__init__()
73 |
74 | self.n_heads = n_heads
75 | self.n_rep = 1
76 | self.head_dim = dim // n_heads
77 |
78 | self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
79 | self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
80 | self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
81 | self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
82 |
83 | self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
84 | self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)
85 |
86 | @staticmethod
87 | def reshape_for_broadcast(freqs_cis, x):
88 | ndim = x.ndim
89 | assert 0 <= 1 < ndim
90 | # assert freqs_cis.shape == (x.shape[1], x.shape[-1])
91 | _freqs_cis = freqs_cis[: x.shape[1]]
92 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
93 | return _freqs_cis.view(*shape)
94 |
95 | @staticmethod
96 | def apply_rotary_emb(xq, xk, freqs_cis):
97 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
98 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
99 | freqs_cis_xq = Attention.reshape_for_broadcast(freqs_cis, xq_)
100 | freqs_cis_xk = Attention.reshape_for_broadcast(freqs_cis, xk_)
101 |
102 | xq_out = torch.view_as_real(xq_ * freqs_cis_xq).flatten(3)
103 | xk_out = torch.view_as_real(xk_ * freqs_cis_xk).flatten(3)
104 | return xq_out, xk_out
105 |
106 | def forward(self, x, freqs_cis):
107 | bsz, seqlen, _ = x.shape
108 |
109 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
110 |
111 | dtype = xq.dtype
112 |
113 | xq = self.q_norm(xq)
114 | xk = self.k_norm(xk)
115 |
116 | xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
117 | xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
118 | xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
119 |
120 | xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
121 | xq, xk = xq.to(dtype), xk.to(dtype)
122 |
123 | output = F.scaled_dot_product_attention(
124 | xq.permute(0, 2, 1, 3),
125 | xk.permute(0, 2, 1, 3),
126 | xv.permute(0, 2, 1, 3),
127 | dropout_p=0.0,
128 | is_causal=False,
129 | ).permute(0, 2, 1, 3)
130 | output = output.flatten(-2)
131 |
132 | return self.wo(output)
133 |
134 |
135 | class FeedForward(nn.Module):
136 | def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
137 | super().__init__()
138 | hidden_dim = int(2 * hidden_dim / 3)
139 | if ffn_dim_multiplier:
140 | hidden_dim = int(ffn_dim_multiplier * hidden_dim)
141 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
142 |
143 | self.w1 = nn.Linear(dim, hidden_dim, bias=False)
144 | self.w2 = nn.Linear(hidden_dim, dim, bias=False)
145 | self.w3 = nn.Linear(dim, hidden_dim, bias=False)
146 |
147 | def _forward_silu_gating(self, x1, x3):
148 | return F.silu(x1) * x3
149 |
150 | def forward(self, x):
151 | return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
152 |
153 |
154 | class TransformerBlock(nn.Module):
155 | def __init__(
156 | self,
157 | layer_id,
158 | dim,
159 | n_heads,
160 | multiple_of,
161 | ffn_dim_multiplier,
162 | norm_eps,
163 | ):
164 | super().__init__()
165 | self.dim = dim
166 | self.head_dim = dim // n_heads
167 | self.attention = Attention(dim, n_heads)
168 | self.feed_forward = FeedForward(
169 | dim=dim,
170 | hidden_dim=4 * dim,
171 | multiple_of=multiple_of,
172 | ffn_dim_multiplier=ffn_dim_multiplier,
173 | )
174 | self.layer_id = layer_id
175 | self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
176 | self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
177 |
178 | self.adaLN_modulation = nn.Sequential(
179 | nn.SiLU(),
180 | nn.Linear(min(dim, 1024), 6 * dim, bias=True),
181 | )
182 |
183 | def forward(self, x, freqs_cis, adaln_input=None):
184 | if adaln_input is not None:
185 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(
186 | 6, dim=1
187 | )
188 |
189 | x = x + gate_msa.unsqueeze(1) * self.attention(
190 | modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis
191 | )
192 | x = x + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(x), shift_mlp, scale_mlp))
193 | else:
194 | x = x + self.attention(self.attention_norm(x), freqs_cis)
195 | x = x + self.feed_forward(self.ffn_norm(x))
196 |
197 | return x
198 |
199 |
200 | class FinalLayer(nn.Module):
201 | def __init__(self, hidden_size, out_channels):
202 | super().__init__()
203 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
204 | self.linear = nn.Linear(hidden_size, out_channels, bias=True)
205 | self.adaLN_modulation = nn.Sequential(
206 | nn.SiLU(),
207 | nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
208 | )
209 | # # init zero
210 | nn.init.constant_(self.linear.weight, 0)
211 | nn.init.constant_(self.linear.bias, 0)
212 |
213 | def forward(self, x, c):
214 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
215 | x = modulate(self.norm_final(x), shift, scale)
216 | x = self.linear(x)
217 | return x
218 |
219 |
220 | class DiT_Llama(ModelMixin, ConfigMixin):
221 |
222 | @register_to_config
223 | def __init__(
224 | self,
225 | embedding_dim=3,
226 | hidden_dim=512,
227 | n_layers=5,
228 | n_heads=16,
229 | multiple_of=256,
230 | ffn_dim_multiplier=None,
231 | norm_eps=1e-5,
232 | ):
233 | super().__init__()
234 |
235 | self.in_channels = embedding_dim
236 | self.out_channels = embedding_dim
237 |
238 | self.x_embedder = nn.Linear(embedding_dim, hidden_dim, bias=True)
239 | nn.init.constant_(self.x_embedder.bias, 0)
240 |
241 | self.t_embedder = TimestepEmbedder(min(hidden_dim, 1024))
242 | # self.y_embedder = LabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob)
243 |
244 | self.layers = nn.ModuleList(
245 | [
246 | TransformerBlock(
247 | layer_id,
248 | hidden_dim,
249 | n_heads,
250 | multiple_of,
251 | ffn_dim_multiplier,
252 | norm_eps,
253 | )
254 | for layer_id in range(n_layers)
255 | ]
256 | )
257 | self.final_layer = FinalLayer(hidden_dim, self.out_channels)
258 |
259 | self.freqs_cis = DiT_Llama.precompute_freqs_cis(hidden_dim // n_heads, 4096)
260 |
261 | def forward(self, x, t, cond):
262 | self.freqs_cis = self.freqs_cis.to(x.device)
263 |
264 | x = torch.cat([x, cond], dim=1)
265 |
266 | x = self.x_embedder(x)
267 |
268 | t = self.t_embedder(t) # (N, D)
269 | adaln_input = t.to(x.dtype)
270 |
271 | for layer in self.layers:
272 | x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)
273 |
274 | x = self.final_layer(x, adaln_input)
275 | # Drop the cond part
276 | x = x[:, : -cond.size(1)]
277 | return x
278 |
279 | def forward_with_cfg(self, x, t, cond, cfg_scale):
280 | half = x[: len(x) // 2]
281 | combined = torch.cat([half, half], dim=0)
282 | model_out = self.forward(combined, t, cond)
283 | eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
284 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
285 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
286 | eps = torch.cat([half_eps, half_eps], dim=0)
287 | return torch.cat([eps, rest], dim=1)
288 |
289 | @staticmethod
290 | def precompute_freqs_cis(dim, end, theta=10000.0):
291 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
292 | t = torch.arange(end)
293 | freqs = torch.outer(t, freqs).float()
294 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
295 | return freqs_cis
296 |
297 |
298 | def DiT_base(**kwargs):
299 |     return DiT_Llama(embedding_dim=2048, hidden_dim=2048, n_layers=8, n_heads=32, **kwargs)
300 |
301 |
302 | if __name__ == "__main__":
303 |     # Quick shape check: random sequences of 16 tokens with dim 2048, matching the IP-Prior setup
304 |     model = DiT_Llama(embedding_dim=2048, hidden_dim=512, n_layers=2, n_heads=8)
305 |     model.eval()
306 |     x = torch.randn(2, 16, 2048)
307 |     t = torch.rand(2)
308 |     cond = torch.randn(2, 16, 2048)
309 |     with torch.no_grad():
310 |         out = model(x, t, cond)
311 |         print(out.shape)
312 |         out = model.forward_with_cfg(x, t, cond, 0.5)
313 |         print(out.shape)
314 |
--------------------------------------------------------------------------------
/model/pipeline_pit.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Optional, Union
3 |
4 | import torch
5 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
6 | from diffusers.utils import BaseOutput
7 | from diffusers.utils import (
8 | logging,
9 | )
10 | from diffusers.utils.torch_utils import randn_tensor
11 | from dataclasses import dataclass
12 | from model.dit import DiT_Llama
13 |
14 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15 |
16 |
17 | @dataclass
18 | class PiTPipelineOutput(BaseOutput):
19 | image_embeds: torch.Tensor
20 |
21 |
22 | class PiTPipeline(DiffusionPipeline):
23 |
24 | def __init__(self, prior: DiT_Llama):
25 | super().__init__()
26 |
27 | self.register_modules(
28 | prior=prior,
29 | )
30 |
31 | def prepare_latents(self, shape, dtype, device, generator, latents):
32 | if latents is None:
33 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
34 | else:
35 | if latents.shape != shape:
36 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
37 | latents = latents.to(device)
38 |
39 | return latents
40 |
41 | @torch.no_grad()
42 | def __call__(
43 | self,
44 | cond_sequence: torch.FloatTensor,
45 | negative_cond_sequence: torch.FloatTensor,
46 | num_images_per_prompt: int = 1,
47 | num_inference_steps: int = 25,
48 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
49 | latents: Optional[torch.FloatTensor] = None,
50 | init_latents: Optional[torch.FloatTensor] = None,
51 | strength: Optional[float] = None,
52 | guidance_scale: float = 1.0,
53 | output_type: Optional[str] = "pt", # pt only
54 | return_dict: bool = True,
55 | ):
56 |
57 | do_classifier_free_guidance = guidance_scale > 1.0
58 |
59 | device = self._execution_device
60 |
61 | batch_size = cond_sequence.shape[0]
62 | batch_size = batch_size * num_images_per_prompt
63 |
64 | embedding_dim = self.prior.config.embedding_dim
65 |
66 | latents = self.prepare_latents(
67 | (batch_size, 16, embedding_dim),
68 | self.prior.dtype,
69 | device,
70 | generator,
71 | latents,
72 | )
73 |
74 | if init_latents is not None:
75 | init_latents = init_latents.to(latents.device)
76 | latents = (strength) * latents + (1 - strength) * init_latents
77 |
78 | # Rectified Flow
79 | dt = 1.0 / num_inference_steps
80 | dt = torch.tensor([dt] * batch_size).to(latents.device).view([batch_size, *([1] * len(latents.shape[1:]))])
81 | start_inference_step = (
82 | math.ceil(num_inference_steps * (strength)) if strength is not None else num_inference_steps
83 | )
84 | for i in range(start_inference_step, 0, -1):
85 | t = i / num_inference_steps
86 | t = torch.tensor([t] * batch_size).to(latents.device)
87 |
88 | vc = self.prior(latents, t, cond_sequence)
89 | if do_classifier_free_guidance:
90 | vu = self.prior(latents, t, negative_cond_sequence)
91 | vc = vu + guidance_scale * (vc - vu)
92 |
93 | latents = latents - dt * vc
94 |
95 | image_embeddings = latents
96 |
97 | if output_type not in ["pt", "np"]:
98 |             raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
99 |
100 | if output_type == "np":
101 | image_embeddings = image_embeddings.cpu().numpy()
102 |
103 | if not return_dict:
104 | return image_embeddings
105 |
106 | return PiTPipelineOutput(image_embeds=image_embeddings)
107 |
--------------------------------------------------------------------------------
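The sampling loop above integrates a rectified flow from t=1 down to t=0 with simple Euler steps, optionally mixing in classifier-free guidance between the conditional and negative sequences. As a shape-level illustration only, the following sketch drives the pipeline with a randomly initialized `DiT_Llama` standing in for a trained IP-Prior checkpoint; the three-crops-times-16-tokens conditioning layout mirrors `scripts/infer.py`.

```python
# Shape-level sketch: a randomly initialized DiT_Llama stands in for a trained IP-Prior.
import torch

from model.dit import DiT_Llama
from model.pipeline_pit import PiTPipeline

prior = DiT_Llama(embedding_dim=2048, hidden_dim=512, n_layers=2, n_heads=8)
pipe = PiTPipeline(prior=prior)

# Three conditioning crops, each contributing 16 IP-Adapter+ tokens of dimension 2048.
cond_sequence = torch.randn(1, 3 * 16, 2048)
negative_cond_sequence = torch.zeros_like(cond_sequence)

out = pipe(
    cond_sequence=cond_sequence,
    negative_cond_sequence=negative_cond_sequence,
    num_inference_steps=25,
    guidance_scale=2.0,  # > 1.0 enables classifier-free guidance
    generator=torch.Generator().manual_seed(0),
)
print(out.image_embeds.shape)  # torch.Size([1, 16, 2048])
```

In the actual inference script the negative sequence is built from the embeddings of a blank image rather than zeros, and the resulting `image_embeds` are decoded to pixels by passing them to `IPAdapterPlusXL.generate` as `image_prompt_embeds` (see `scripts/infer.py`).
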
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "PiT"
3 | version = "1.0"
4 | requires-python = ">=3.12"
5 | dependencies = [
6 | "accelerate>=1.2.1",
7 | "diffusers>=0.32.1",
8 | "einops>=0.8.0",
9 | "kornia>=0.8.0",
10 | "matplotlib>=3.10.0",
11 | "opencv-python>=4.10.0.84",
12 | "pandas>=2.2.3",
13 | "peft>=0.14.0",
14 | "protobuf>=5.29.2",
15 | "pyrallis>=0.3.1",
16 | "scikit-learn>=1.6.1",
17 | "scipy>=1.15.0",
18 | "sentencepiece>=0.2.0",
19 | "supervision>=0.25.1",
20 | "tensorboard>=2.18.0",
21 | "timm>=1.0.12",
22 | "torch>=2.5.1",
23 | "torchvision>=0.20.0",
24 | "transformers>=4.47.1",
25 | "wandb>=0.19.1",
26 | ]
27 |
28 |
29 | [tool.black]
30 | line-length = 120
--------------------------------------------------------------------------------
/scripts/generate_characters.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import einops
3 | import math
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pyrallis
7 | import random
8 | import supervision as sv
9 | import torch
10 | from PIL import Image
11 | from utils import words_bank
12 | from dataclasses import dataclass
13 | from diffusers import StableDiffusionXLPipeline, FluxPipeline
14 | from pathlib import Path
15 | from scipy import ndimage
16 | from transformers import SamModel, SamProcessor, AutoProcessor, AutoModelForCausalLM
17 | from transformers import pipeline
18 | from typing import Tuple, Dict
19 |
20 |
21 | @dataclass
22 | class RunConfig:
23 | # Generation mode, should be either 'objects' or 'scenes'
24 | out_dir: Path = Path("datasets/generated/monsters/")
25 | n_images: int = 1000000
26 | vis_data: bool = False
27 | n_samples_in_dir: int = 1000
28 |
29 |
30 | def crop_from_mask(image, mask: np.ndarray):
31 | # Apply mask and crop a tight box
32 | mask = mask.astype(np.uint8)
33 | mask = mask * 255
34 | mask = Image.fromarray(mask)
35 | bbox = mask.getbbox()
36 |
37 | # Create a new image with a white background
38 | white_background = Image.new("RGB", image.size, (255, 255, 255))
39 |
40 | # Apply the mask to the original image
41 | masked_image = Image.composite(image, white_background, mask)
42 |
43 | # Crop the image to the bounding box
44 | cropped_image = masked_image.crop(bbox)
45 |
46 | return cropped_image
47 |
48 |
49 | def show_mask(mask, ax, random_color=False):
50 | if random_color:
51 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
52 | else:
53 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
54 | h, w = mask.shape[-2:]
55 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
56 | ax.imshow(mask_image)
57 |
58 |
59 | def show_masks_on_image(raw_image, masks, caption=None):
60 | plt.imshow(np.array(raw_image))
61 | ax = plt.gca()
62 | ax.set_autoscale_on(False)
63 | for mask in masks:
64 | show_mask(mask, ax=ax, random_color=True)
65 | plt.axis("off")
66 | if caption:
67 | plt.title(caption)
68 | plt.show()
69 |
70 |
71 | def show_box(box, ax):
72 | x0, y0 = box[0], box[1]
73 | w, h = box[2] - box[0], box[3] - box[1]
74 | facecolor = [random.random(), random.random(), random.random(), 0.3]
75 | edgecolor = facecolor.copy()
76 | edgecolor[3] = 1
77 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=facecolor, lw=2))
78 |
79 |
80 | def show_boxes_on_image(raw_image, boxes):
81 | plt.figure(figsize=(10, 10))
82 | plt.imshow(raw_image)
83 | for box in boxes:
84 | show_box(box, plt.gca())
85 | plt.axis("on")
86 | plt.show()
87 |
88 |
89 | @pyrallis.wrap()
90 | def generate(cfg: RunConfig):
91 | cfg.out_dir.mkdir(exist_ok=True, parents=True)
92 |
93 | flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
94 |
95 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96 | model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
97 | processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
98 |
99 | checkpoint = "microsoft/Florence-2-large"
100 | florence_processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
101 | florence_model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to(device).eval()
102 |
103 |     def run_florence_inference(
104 |         image: Image.Image, task: str = "", text: str = ""
105 |     ) -> np.ndarray:
106 | prompt = task + text
107 | inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
108 | generated_ids = florence_model.generate(
109 | input_ids=inputs["input_ids"],
110 | pixel_values=inputs["pixel_values"],
111 | max_new_tokens=1024,
112 | num_beams=3,
113 | output_scores=True,
114 | )
115 | generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
116 | response = florence_processor.post_process_generation(generated_text, task=task, image_size=image.size)
117 |
118 | detections = sv.Detections.from_lmm(lmm=sv.LMM.FLORENCE_2, result=response, resolution_wh=image.size)
119 | input_boxes = detections.xyxy
120 | return input_boxes
121 |
122 | with open("assets/openimages_classes.txt", "r") as f:
123 | objects = f.read().splitlines()
124 | objects = ["".join(char if char.isalnum() else " " for char in object_name) for object_name in objects]
125 |     # Duplicate creatures to roughly match the number of objects, and mix in the object names as well
126 |     creatures = words_bank.creatures * (len(objects) // len(words_bank.creatures) + 1) + objects * 10
127 |
128 | tot_generated = 0
129 | for _ in range(cfg.n_images):
130 | try:
131 | new_dir_name = f"set_{tot_generated}_{random.randint(0, 1000000)}"
132 | out_dir = cfg.out_dir / new_dir_name
133 | out_dir.mkdir(exist_ok=True, parents=True)
134 | monster_grid = []
135 |
136 | for _ in range(cfg.n_samples_in_dir):
137 |
138 | adjective_count = random.randint(2, 6)
139 | adjectives = random.sample(words_bank.adjectives, adjective_count)
140 |                 # adjective_count >= 2, so adjectives is never empty; fall back to "" defensively
141 |                 adjectives_txt = " ".join(adjectives) + " " if adjectives else ""
142 |
143 | character_count = random.randint(1, 3)
144 | characters = random.sample(creatures, character_count)
145 | characters = [f"{c}-like" for c in characters]
146 | character_txt = " ".join(characters)
147 |
148 | prompt = f"studio photo pixar style concept art, An imaginary fantasy {adjectives_txt} {character_txt} creature with eyes arms legs mouth , white background studio photo pixar style asset"
149 | seed = random.randint(0, 1000000)
150 |
151 | print(prompt)
152 | base_image = flux_pipe(
153 | prompt,
154 | guidance_scale=0.0,
155 | num_inference_steps=4,
156 | max_sequence_length=256,
157 | ).images[0]
158 |
159 | input_boxes = []
160 | keywords = words_bank.keywords
161 | for keyword in keywords:
162 | current_boxes = list(run_florence_inference(base_image, text=keyword))
163 | # Randomly choose one
164 | if len(current_boxes) > 0:
165 | input_boxes.extend(random.sample(current_boxes, 1))
166 |
167 | # convert to ints
168 | input_boxes = [[int(x) for x in box] for box in input_boxes]
169 |
170 | inputs = processor(base_image, input_boxes=[input_boxes], return_tensors="pt").to(device)
171 | with torch.no_grad():
172 | outputs = model(**inputs)
173 | masks = processor.image_processor.post_process_masks(
174 | outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
175 | )[0]
176 | #
177 | masks = [mask[0].cpu().detach().numpy() for mask in masks]
178 |
179 | masks = sorted(masks, key=lambda mask: mask.sum(), reverse=False)
180 |
181 | # Filter the masks
182 | masks = [mask for mask in masks if 0.015 < mask.sum() / mask.flatten().shape[0] < 0.3]
183 |
184 | for m_ind in range(len(masks)):
185 | # Apply dilate and erode to the mask
186 | mask = masks[m_ind].astype(np.uint8)
187 | mask = cv2.dilate(mask, np.ones((15, 15), np.uint8), iterations=1)
188 | mask = cv2.erode(mask, np.ones((15, 15), np.uint8), iterations=1)
189 | # Now do the reverse to get rid of the small holes
190 | mask = cv2.erode(mask, np.ones((15, 15), np.uint8), iterations=1)
191 | mask = cv2.dilate(mask, np.ones((15, 15), np.uint8), iterations=1)
192 | if True or random.random() < 0.5:
193 | # Close mask
194 | mask = ndimage.binary_fill_holes(mask.astype(int))
195 | masks[m_ind] = mask == 1
196 |
197 | masks = [mask for mask in masks if 0.015 < mask.sum() / mask.flatten().shape[0] < 0.3]
198 |
199 | if cfg.vis_data:
200 | plt.imshow(base_image)
201 | plt.show()
202 | show_masks_on_image(base_image, masks)
203 |
204 | visited_area = np.zeros_like(masks[0])
205 | prompt_hash = str(abs(hash(prompt)))
206 |
207 | for i, mask in enumerate(masks):
208 | # Check if overlaps with visited_area
209 | if (visited_area * mask).sum() > 0:
210 | continue
211 | visited_area += mask
212 |
213 | cropped_image = crop_from_mask(base_image, mask)
214 | out_path = out_dir / f"{prompt_hash}_{seed}_{i}.jpg"
215 | cropped_image.save(out_path)
216 | if cfg.vis_data:
217 | plt.imshow(cropped_image)
218 | plt.show()
219 |
220 | out_path = out_dir / f"{prompt_hash}_{seed}.jpg"
221 | tot_generated += 1
222 | base_image.save(out_path)
223 | monster_grid.append(base_image)
224 |
225 | except Exception as e:
226 | print(e)
227 |
228 |
229 | if __name__ == "__main__":
230 | # Use to generate objects or backgrounds
231 | generate()
232 |
--------------------------------------------------------------------------------
/scripts/generate_products.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 |
4 | import random
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 |
8 | import cv2
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pyrallis
12 | import torch
13 | from PIL import Image
14 | from diffusers import FluxPipeline
15 | from scipy import ndimage
16 | from transformers import pipeline
17 |
18 | from utils import words_bank
19 |
20 |
21 | @dataclass
22 | class RunConfig:
23 | out_dir: Path = Path("datasets/generated/products")
24 | n_images: int = 1000000
25 | vis_data: bool = False
26 | n_samples_in_dir: int = 1000
27 |
28 |
29 | def crop_from_mask(image, mask: np.ndarray):
30 | # Apply mask and crop a tight box
31 | mask = mask.astype(np.uint8)
32 | mask = mask * 255
33 | mask = Image.fromarray(mask)
34 | bbox = mask.getbbox()
35 |
36 | # Create a new image with a white background
37 | white_background = Image.new("RGB", image.size, (255, 255, 255))
38 |
39 | # Apply the mask to the original image
40 | masked_image = Image.composite(image, white_background, mask)
41 |
42 | # Crop the image to the bounding box
43 | cropped_image = masked_image.crop(bbox)
44 |
45 | return cropped_image
46 |
47 |
48 | def show_mask(mask, ax, random_color=False):
49 | if random_color:
50 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
51 | else:
52 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
53 | h, w = mask.shape[-2:]
54 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
55 | ax.imshow(mask_image)
56 |
57 |
58 | def show_masks_on_image(raw_image, masks, caption=None):
59 | plt.imshow(np.array(raw_image))
60 | ax = plt.gca()
61 | ax.set_autoscale_on(False)
62 | for mask in masks:
63 | show_mask(mask, ax=ax, random_color=True)
64 | plt.axis("off")
65 | if caption:
66 | plt.title(caption)
67 | plt.show()
68 |
69 |
70 | @pyrallis.wrap()
71 | def generate(cfg: RunConfig):
72 | cfg.out_dir.mkdir(exist_ok=True, parents=True)
73 |
74 | flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
75 |
76 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77 | segmentor = pipeline("mask-generation", model="facebook/sam-vit-base", device=device)
78 |
79 | with open("assets/openimages_classes.txt", "r") as f:
80 | objects = f.read().splitlines()
81 | objects = ["".join(char if char.isalnum() else " " for char in object_name) for object_name in objects]
82 |
83 | tot_generated = 0
84 | for _ in range(cfg.n_images):
85 | new_dir_name = f"set_{tot_generated}_{random.randint(0, 1000000)}"
86 | out_dir = cfg.out_dir / new_dir_name
87 | out_dir.mkdir(exist_ok=True, parents=True)
88 | monster_grid = []
89 |
90 | for _ in range(1000):
91 | try:
92 | character_count = random.randint(0, 3)
93 | if character_count == 0:
94 | character_txt = ""
95 | else:
96 | characters = random.sample(objects, character_count)
97 | # For each character text take only one random word
98 | characters = [random.choice(character.split()) for character in characters]
99 | characters = [f"{c}-like" for c in characters]
100 | character_txt = " ".join(characters)
101 |
102 | attributes_count = random.randint(1, 3)
103 | material_count = random.randint(1, 2)
104 | attributes = random.sample(words_bank.object_attributes, attributes_count)
105 | materials = random.sample(words_bank.product_materials, material_count)
106 | features = random.sample(words_bank.product_defining_attributes, 1)
107 | attributes_and_materials_txt = " ".join(attributes + materials + features)
108 |
109 |                 prompt = f"A product design photo of a {attributes_and_materials_txt} product with {character_txt} attributes, integrated together to create one seamless product. It is set against a light gray background with a soft gradient, creating a neutral and elegant backdrop that emphasizes the contemporary design. The soft, even lighting highlights the contours and textures, lending a professional, polished quality to the composition"
110 | seed = random.randint(0, 1000000)
111 |
112 | print(prompt)
113 | base_image = flux_pipe(
114 | prompt,
115 | guidance_scale=0.0,
116 | num_inference_steps=4,
117 | max_sequence_length=256,
118 | ).images[0]
119 |
120 | if cfg.vis_data:
121 | plt.imshow(base_image)
122 | plt.title(f"{attributes_and_materials_txt} {character_txt}")
123 | plt.show()
124 | # continue
125 | all_masks = segmentor(base_image, points_per_batch=64)["masks"]
126 |
127 | if len(all_masks) == 0:
128 | continue
129 |
130 | # Sort by area
131 | all_masks = sorted(all_masks, key=lambda mask: mask.sum(), reverse=False)
132 |                 # Drop the largest mask (last after sorting by area)
133 |                 masks = all_masks[:-1]
134 |
135 | if len(all_masks) < 3:
136 | # For now take only things with at least 3 parts to keep the data interesting
137 | continue
138 |
139 | # Remove masks that intersect with image boundary
140 | mask_boundary = np.zeros_like(masks[0])
141 | mask_boundary[0, :] = 1
142 | mask_boundary[-1, :] = 1
143 | mask_boundary[:, 0] = 1
144 | mask_boundary[:, -1] = 1
145 |
146 | masks = [mask for mask in masks if (mask * mask_boundary).sum() == 0]
147 |
148 | masks = [mask for mask in masks if 0.015 < mask.sum() / mask.flatten().shape[0] < 0.3]
149 |
150 | for m_ind in range(len(masks)):
151 | # Apply dilate and erode to the mask
152 | mask = masks[m_ind].astype(np.uint8)
153 | mask = cv2.dilate(mask, np.ones((15, 15), np.uint8), iterations=1)
154 | mask = cv2.erode(mask, np.ones((15, 15), np.uint8), iterations=1)
155 | # Now do the reverse to get rid of the small holes
156 | mask = cv2.erode(mask, np.ones((15, 15), np.uint8), iterations=1)
157 | mask = cv2.dilate(mask, np.ones((15, 15), np.uint8), iterations=1)
158 | if True or random.random() < 0.5:
159 | # Close mask
160 | mask = ndimage.binary_fill_holes(mask.astype(int))
161 | masks[m_ind] = mask == 1
162 |
163 | masks = [mask for mask in masks if 0.015 < mask.sum() / mask.flatten().shape[0] < 0.3]
164 |
165 | if len(masks) == 0:
166 | print(f"No masks found for {character_txt}")
167 | continue
168 |
169 | # Restrict to 8
170 | masks = masks[:8]
171 |
172 | if cfg.vis_data:
173 | show_masks_on_image(base_image, masks)
174 |
175 | visited_area = np.zeros_like(masks[0])
176 | prompt_hash = str(abs(hash(prompt)))
177 |
178 | for i, mask in enumerate(masks):
179 | # Check if overlaps with visited_area
180 | if (visited_area * mask).sum() > 0:
181 | continue
182 | visited_area += mask
183 |
184 | cropped_image = crop_from_mask(base_image, mask)
185 | cropped_image.thumbnail((256, 256))
186 | out_path = out_dir / f"{prompt_hash}_{seed}_{i}.jpg"
187 | cropped_image.save(out_path)
188 | if cfg.vis_data:
189 | plt.imshow(cropped_image)
190 | plt.show()
191 |
192 | out_path = out_dir / f"{prompt_hash}_{seed}.jpg"
193 | base_image.thumbnail((512, 512))
194 | tot_generated += 1
195 | base_image.save(out_path)
196 | if len(monster_grid) < 9:
197 | monster_grid.append(base_image)
198 | except Exception as e:
199 | print(e)
200 | continue
201 |
202 |
203 | if __name__ == "__main__":
204 | # Use to generate objects or backgrounds
205 | generate()
206 |
--------------------------------------------------------------------------------
/scripts/infer.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import numpy as np
7 | import pyrallis
8 | import torch
9 | from PIL import Image
10 | from diffusers import (
11 | StableDiffusionXLPipeline,
12 | )
13 | from huggingface_hub import hf_hub_download
14 | from tqdm import tqdm
15 |
16 | from ip_adapter import IPAdapterPlusXL
17 | from model.dit import DiT_Llama
18 | from model.pipeline_pit import PiTPipeline
19 | from training.train_config import TrainConfig
20 | from utils import vis_utils, bezier_utils
21 |
22 |
23 | def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None):
24 | # Calculate aspect ratio and determine resizing based on the smaller dimension of the background
25 | aspect_ratio = image.width / image.height
26 | scale = random.uniform(min_scale, max_scale) if scale is None else scale
27 | new_width = int(min(background.width, background.height * aspect_ratio) * scale)
28 | new_height = int(new_width / aspect_ratio)
29 |
30 | # Resize image and calculate position
31 | image = image.resize((new_width, new_height), resample=Image.LANCZOS)
32 | pos_x = random.randint(0, background.width - new_width)
33 | pos_y = random.randint(0, background.height - new_height)
34 |
35 | # Paste the image using its alpha channel as mask if present
36 | background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
37 | return background
38 |
39 |
40 | def set_seed(seed: int):
41 | """Ensures reproducibility across multiple libraries."""
42 | random.seed(seed) # Python random module
43 | np.random.seed(seed) # NumPy random module
44 | torch.manual_seed(seed) # PyTorch CPU random seed
45 | torch.cuda.manual_seed_all(seed) # PyTorch GPU random seed
46 | torch.backends.cudnn.deterministic = True # Ensures deterministic behavior
47 | torch.backends.cudnn.benchmark = False # Disable benchmarking to avoid randomness
48 |
49 |
50 |
51 |
52 |
53 | @dataclass
54 | class RunConfig:
55 | prior_path: Path
56 | crops_dir: Path
57 | output_dir: Path
58 | prior_repo: Optional[str] = None
59 | prior_guidance_scale: float = 1.0
60 | drop_cond: bool = True
61 | n_randoms: int = 10
62 | as_sketch: bool = False
63 | scale: float = 2.0
64 | use_empty_ref: bool = True
65 |
66 |
67 | @pyrallis.wrap()
68 | def main(cfg: RunConfig):
69 | output_dir = cfg.output_dir
70 | output_dir.mkdir(parents=True, exist_ok=True)
71 |
72 | # Download model and config
73 | if cfg.prior_repo is not None:
74 | prior_ckpt_path = hf_hub_download(
75 | repo_id=cfg.prior_repo,
76 | filename=str(cfg.prior_path),
77 | local_dir="pretrained_models",
78 | )
79 | prior_cfg_path = hf_hub_download(
80 | repo_id=cfg.prior_repo, filename=str(cfg.prior_path.parent / "cfg.yaml"), local_dir="pretrained_models"
81 | )
82 | else:
83 | prior_ckpt_path = cfg.prior_path
84 | prior_cfg_path = cfg.prior_path.parent / "cfg.yaml"
85 |
86 | # Load model_cfg from file
87 | model_cfg: TrainConfig = pyrallis.load(TrainConfig, open(prior_cfg_path, "r"))
88 |
89 | weight_dtype = torch.float32
90 | device = "cuda:0"
91 | prior = DiT_Llama(
92 | embedding_dim=2048,
93 | hidden_dim=model_cfg.hidden_dim,
94 | n_layers=model_cfg.num_layers,
95 | n_heads=model_cfg.num_attention_heads,
96 | )
97 | prior.load_state_dict(torch.load(prior_ckpt_path))
98 |
99 | image_pipe = StableDiffusionXLPipeline.from_pretrained(
100 | "stabilityai/stable-diffusion-xl-base-1.0",
101 | torch_dtype=torch.float16,
102 | add_watermarker=False,
103 | )
104 |
105 | ip_ckpt_path = hf_hub_download(
106 | repo_id="h94/IP-Adapter",
107 | filename="ip-adapter-plus_sdxl_vit-h.bin",
108 | subfolder="sdxl_models",
109 | local_dir="pretrained_models",
110 | )
111 |
112 | ip_model = IPAdapterPlusXL(
113 | image_pipe,
114 | "models/image_encoder",
115 | ip_ckpt_path,
116 | device,
117 | num_tokens=16,
118 | )
119 |
120 | image_processor = ip_model.clip_image_processor
121 |
122 | empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
123 | zero_image = torch.Tensor(image_processor(empty_image)["pixel_values"][0])
124 | zero_image_embeds = ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)
125 |
126 | prior_pipeline = PiTPipeline(
127 | prior=prior,
128 | )
129 | prior_pipeline = prior_pipeline.to(device)
130 |
131 | set_seed(42)
132 |
133 | # Read all crops from the dir
134 | crop_sets = []
135 | unordered_crops = []
136 | for crop_dir_path in cfg.crops_dir.iterdir():
137 | unordered_crops.append(crop_dir_path)
138 | unordered_crops = sorted(unordered_crops, key=lambda x: x.stem)
139 |
140 | if len(unordered_crops) > 0:
141 | for _ in range(cfg.n_randoms):
142 | n_crops = random.randint(1, min(3, len(unordered_crops)))
143 | crop_paths = random.sample(unordered_crops, n_crops)
144 | # Some of the paths might be dirs, if it is a dir, take a random file from it
145 | crop_paths = [c if c.is_file() else random.choice([f for f in c.iterdir()]) for c in crop_paths]
146 | crop_sets.append(crop_paths)
147 |
148 | if model_cfg.use_ref:
149 | if cfg.use_empty_ref:
150 | print(f"----- USING EMPTY GRIDS -----")
151 | augmented_crop_sets = [[None] + crop_set for crop_set in crop_sets]
152 | else:
153 | print(f"----- USING REFERENCE GRIDS -----")
154 | augmented_crop_sets = []
155 | refs_dir = Path("assets/ref_grids")
156 | refs = [f for f in refs_dir.iterdir()]
157 | for crop_set in crop_sets:
158 | # Choose a subset of refs
159 | chosen_refs = random.sample(refs, 1) # [None] # + random.sample(refs, 5)
160 | for ref in chosen_refs:
161 | augmented_crop_sets.append([ref] + crop_set)
162 |
163 | crop_sets = augmented_crop_sets
164 |
165 | random.shuffle(crop_sets)
166 |
167 | for crop_paths in tqdm(crop_sets):
168 | out_name = f"{random.randint(0, 1000000)}"
169 |
170 | processed_crops = []
171 | input_images = []
172 | captions = []
173 |
174 | # Extend to >3 with Nones
175 | while len(crop_paths) < 3:
176 | crop_paths.append(None)
177 |
178 | for path_ind, path in enumerate(crop_paths):
179 | if path is None:
180 | image = Image.new("RGB", (224, 224), (255, 255, 255))
181 | else:
182 | image = Image.open(path).convert("RGB")
183 | if path_ind > 0 or not model_cfg.use_ref:
184 | background = Image.new("RGB", (1024, 1024), (255, 255, 255))
185 | image = paste_on_background(image, background, scale=0.92)
186 | else:
187 | image = image.resize((1024, 1024))
188 | if cfg.as_sketch and random.random() < 0.5:
189 | num_lines = random.randint(8, 15)
190 | image = bezier_utils.get_sketch(image, total_curves=num_lines, drop_line_prob=0.1)
191 | input_images.append(image)
192 | # Name should be parent directory name
193 | captions.append(path.parent.stem)
194 | processed_image = (
195 | torch.Tensor(image_processor(image)["pixel_values"][0]).to(device).unsqueeze(0).to(weight_dtype)
196 | )
197 | processed_crops.append(processed_image)
198 |
199 | image_embed_inputs = []
200 | for crop_ind in range(len(processed_crops)):
201 | image_embed_inputs.append(ip_model.get_image_embeds(processed_crops[crop_ind], skip_uncond=True))
202 | crops_input_sequence = torch.cat(image_embed_inputs, dim=1)
203 |
204 | for _ in range(4):
205 | seed = random.randint(0, 1000000)
206 | for scale in [cfg.scale]:
207 | negative_cond_sequence = torch.zeros_like(crops_input_sequence)
208 | embeds_len = zero_image_embeds.shape[1]
209 | for i in range(0, negative_cond_sequence.shape[1], embeds_len):
210 | negative_cond_sequence[:, i : i + embeds_len] = zero_image_embeds.detach()
211 |
212 | img_emb = prior_pipeline(
213 | cond_sequence=crops_input_sequence,
214 | negative_cond_sequence=negative_cond_sequence,
215 | num_inference_steps=25,
216 | num_images_per_prompt=1,
217 | guidance_scale=scale,
218 | generator=torch.Generator(device="cuda").manual_seed(seed),
219 | ).image_embeds
220 |
221 | for seed_2 in range(1):
222 | images = ip_model.generate(
223 | image_prompt_embeds=img_emb,
224 | num_samples=1,
225 | num_inference_steps=50,
226 | )
227 | input_images += images
228 | captions.append(f"prior_s {seed}, cfg {scale}")
229 | # Build a labeled summary grid; the individual images are also saved to a separate folder below
230 | gen_images = vis_utils.create_table_plot(images=input_images, captions=captions)
231 |
232 | gen_images.save(output_dir / f"{out_name}.jpg")
233 |
234 | # Also save the divided images in a separate folder whose name is the same as the output image
235 | divided_images_dir = output_dir / f"{out_name}_divided"
236 | divided_images_dir.mkdir(parents=True, exist_ok=True)
237 | for i, img in enumerate(input_images):
238 | img.save(divided_images_dir / f"{i}.jpg")
239 | print("Done!")
240 |
241 |
242 | if __name__ == "__main__":
243 | main()
244 |
--------------------------------------------------------------------------------
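
For readers skimming the generation script above, the classifier-free-guidance setup may be easier to follow in isolation: the negative conditioning is simply the IP-Adapter+ embedding of a blank white image, tiled over every 16-token segment of the concatenated crop sequence. Below is a minimal, self-contained sketch with random stand-in tensors (the shapes are illustrative assumptions based on `num_tokens=16` and the prior's 2048-dim embeddings, not values read from the checkpoints):

```python
import torch

num_tokens = 16   # tokens produced by IP-Adapter+ per image (num_tokens=16 above)
embed_dim = 2048  # embedding width consumed by the DiT prior (embedding_dim=2048 in coach.py)
n_crops = 3

# Stand-ins for ip_model.get_image_embeds(...) outputs
crop_embeds = [torch.randn(1, num_tokens, embed_dim) for _ in range(n_crops)]
zero_image_embeds = torch.randn(1, num_tokens, embed_dim)  # embedding of a blank white image

# Positive conditioning: concatenate the per-crop token sequences along the token axis
cond_sequence = torch.cat(crop_embeds, dim=1)  # shape (1, n_crops * num_tokens, embed_dim)

# Negative conditioning: tile the blank-image embedding over every num_tokens-long segment
negative_cond_sequence = torch.zeros_like(cond_sequence)
for i in range(0, negative_cond_sequence.shape[1], num_tokens):
    negative_cond_sequence[:, i : i + num_tokens] = zero_image_embeds

assert negative_cond_sequence.shape == cond_sequence.shape
```
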
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import pyrallis
2 |
3 | from training.coach import Coach
4 | from training.train_config import TrainConfig
5 |
6 |
7 | @pyrallis.wrap()
8 | def main(cfg: TrainConfig):
9 | coach = Coach(cfg)
10 | coach.train()
11 |
12 |
13 | if __name__ == "__main__":
14 | main()
15 |
--------------------------------------------------------------------------------
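
`scripts/train.py` above is a thin pyrallis wrapper around `training.coach.Coach`. On the command line, pyrallis typically reads a YAML such as `configs/train/train_characters.yaml` via `--config_path` and accepts per-field overrides as flags; a rough programmatic equivalent is sketched below (paths are illustrative and assume the datasets and pretrained IP-Adapter weights described in the README are already in place):

```python
from pathlib import Path

from training.coach import Coach
from training.train_config import TrainConfig

# Field names and defaults come from training/train_config.py; the paths are illustrative.
cfg = TrainConfig(
    dataset_path=Path("datasets/generated/generated_things"),
    val_dataset_path=Path("datasets/generated/generated_things_val"),
    output_dir=Path("results/my_pit_model"),
    train_batch_size=1,
    use_ref=False,
)

# Coach builds the DiT prior, the SDXL + IP-Adapter+ decoder, and the dataloaders,
# then runs the flow-matching training loop.
Coach(cfg).train()
```
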
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/training/__init__.py
--------------------------------------------------------------------------------
/training/coach.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | from pathlib import Path
4 |
5 | import diffusers
6 | import pyrallis
7 | import torch
8 | import torch.utils.checkpoint
9 | import transformers
10 | from PIL import Image
11 | from accelerate import Accelerator
12 | from accelerate.logging import get_logger
13 | from accelerate.utils import ProjectConfiguration, set_seed
14 | from diffusers import StableDiffusionXLPipeline
15 | from huggingface_hub import hf_hub_download
16 | from torchvision import transforms
17 | from tqdm import tqdm
18 |
19 | from ip_adapter import IPAdapterPlusXL
20 | from model.dit import DiT_Llama
21 | from model.pipeline_pit import PiTPipeline
22 | from training.dataset import (
23 | PartsDataset,
24 | )
25 | from training.train_config import TrainConfig
26 | from utils import vis_utils
27 |
28 | logger = get_logger(__name__, log_level="INFO")
29 |
30 |
31 | class Coach:
32 | def __init__(self, config: TrainConfig):
33 | self.cfg = config
34 | self.cfg.output_dir.mkdir(exist_ok=True, parents=True)
35 | (self.cfg.output_dir / "cfg.yaml").write_text(pyrallis.dump(self.cfg))
36 | (self.cfg.output_dir / "run.sh").write_text(f'python {" ".join(sys.argv)}')
37 |
38 | self.logging_dir = self.cfg.output_dir / "logs"
39 | accelerator_project_config = ProjectConfiguration(
40 | total_limit=2, project_dir=self.cfg.output_dir, logging_dir=self.logging_dir
41 | )
42 | self.accelerator = Accelerator(
43 | gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
44 | mixed_precision=self.cfg.mixed_precision,
45 | log_with=self.cfg.report_to,
46 | project_config=accelerator_project_config,
47 | )
48 |
49 | self.device = "cuda"
50 |
51 | logger.info(self.accelerator.state, main_process_only=False)
52 | if self.accelerator.is_local_main_process:
53 | transformers.utils.logging.set_verbosity_warning()
54 | diffusers.utils.logging.set_verbosity_info()
55 | else:
56 | transformers.utils.logging.set_verbosity_error()
57 | diffusers.utils.logging.set_verbosity_error()
58 |
59 | if self.cfg.seed is not None:
60 | set_seed(self.cfg.seed)
61 |
62 | if self.accelerator.is_main_process:
63 | self.logging_dir.mkdir(exist_ok=True, parents=True)
64 |
65 | self.weight_dtype = torch.float32
66 | if self.accelerator.mixed_precision == "fp16":
67 | self.weight_dtype = torch.float16
68 | elif self.accelerator.mixed_precision == "bf16":
69 | self.weight_dtype = torch.bfloat16
70 |
71 | self.prior = DiT_Llama(
72 | embedding_dim=2048,
73 | hidden_dim=self.cfg.hidden_dim,
74 | n_layers=self.cfg.num_layers,
75 | n_heads=self.cfg.num_attention_heads,
76 | )
77 | # pretty print total number of parameters in Billions
78 | num_params = sum(p.numel() for p in self.prior.parameters())
79 | print(f"Number of parameters: {num_params / 1e9:.2f}B")
80 |
81 | self.image_pipe = StableDiffusionXLPipeline.from_pretrained(
82 | "stabilityai/stable-diffusion-xl-base-1.0",
83 | torch_dtype=torch.float16,
84 | add_watermarker=False,
85 | ).to(self.device)
86 |
87 | ip_ckpt_path = hf_hub_download(
88 | repo_id="h94/IP-Adapter",
89 | filename="ip-adapter-plus_sdxl_vit-h.bin",
90 | subfolder="sdxl_models",
91 | local_dir="pretrained_models",
92 | )
93 |
94 | self.ip_model = IPAdapterPlusXL(
95 | self.image_pipe,
96 | "models/image_encoder",
97 | ip_ckpt_path,
98 | self.device,
99 | num_tokens=16,
100 | )
101 |
102 | self.image_processor = self.ip_model.clip_image_processor
103 |
104 | empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
105 | zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0])
106 | self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)
107 |
108 | self.prior_pipeline = PiTPipeline(prior=self.prior)
109 | self.prior_pipeline = self.prior_pipeline.to(self.accelerator.device)
110 |
111 | params_to_optimize = list(self.prior.parameters())
112 |
113 | self.optimizer = torch.optim.AdamW(
114 | params_to_optimize,
115 | lr=self.cfg.lr,
116 | betas=(self.cfg.adam_beta1, self.cfg.adam_beta2),
117 | weight_decay=self.cfg.adam_weight_decay,
118 | eps=self.cfg.adam_epsilon,
119 | )
120 |
121 | self.train_dataloader, self.validation_dataloader = self.get_dataloaders()
122 |
123 | self.prior, self.optimizer, self.train_dataloader = self.accelerator.prepare(
124 | self.prior, self.optimizer, self.train_dataloader
125 | )
126 |
127 | self.train_step = 0 if self.cfg.resume_from_step is None else self.cfg.resume_from_step
128 | print(f"Starting from step {self.train_step}")
129 |
130 | if self.cfg.resume_from_path is not None:
131 | prior_state_dict = torch.load(self.cfg.resume_from_path, map_location=self.device)
132 | msg = self.prior.load_state_dict(prior_state_dict, strict=False)
133 | print(msg)
134 |
135 | def save_model(self, save_path):
136 | save_path.mkdir(exist_ok=True, parents=True)
137 | prior_state_dict = self.prior.state_dict()
138 | torch.save(prior_state_dict, save_path / "prior.ckpt")
139 |
140 | def unnormalize_and_pil(self, tensor):
141 | unnormed = tensor * torch.tensor(self.image_processor.image_std).view(3, 1, 1).to(tensor.device) + torch.tensor(
142 | self.image_processor.image_mean
143 | ).view(3, 1, 1).to(tensor.device)
144 | return transforms.ToPILImage()(unnormed)
145 |
146 | def save_images(self, image, conds, cond_sequence, target_embeds, label="", save_path=""):
147 | self.prior.eval()
148 | input_images = []
149 | captions = []
150 | for i in range(len(conds)):
151 | pil_image = self.unnormalize_and_pil(conds[i]).resize((self.cfg.img_size, self.cfg.img_size))
152 | input_images.append(pil_image)
153 | captions.append("Condition")
154 | if image is not None:
155 | input_images.append(self.unnormalize_and_pil(image).resize((self.cfg.img_size, self.cfg.img_size)))
156 | captions.append(f"Target {label}")
157 |
158 | seeds = range(2)
159 | output_images = []
160 | embebds_to_vis = []
161 | embeds_captions = []
162 | embebds_to_vis += [target_embeds]
163 | embeds_captions += ["Target Reconstruct" if image is not None else "Source Reconstruct"]
164 | if self.cfg.use_ref:
165 | embebds_to_vis += [cond_sequence[:, :16]]
166 | embeds_captions += ["Grid Reconstruct"]
167 | for embs in embebds_to_vis:
168 | direct_from_emb = self.ip_model.generate(image_prompt_embeds=embs, num_samples=1, num_inference_steps=50)
169 | output_images = output_images + direct_from_emb
170 | captions += embeds_captions
171 |
172 | for seed in seeds:
173 | for scale in [1, 4]:
174 | negative_cond_sequence = torch.zeros_like(cond_sequence)
175 | embeds_len = self.zero_image_embeds.shape[1]
176 | for i in range(0, negative_cond_sequence.shape[1], embeds_len):
177 | negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach()
178 | img_emb = self.prior_pipeline(
179 | cond_sequence=cond_sequence,
180 | negative_cond_sequence=negative_cond_sequence,
181 | num_inference_steps=25,
182 | num_images_per_prompt=1,
183 | guidance_scale=scale,
184 | generator=torch.Generator(device="cuda").manual_seed(seed),
185 | ).image_embeds
186 |
187 | for seed_2 in range(1):
188 | images = self.ip_model.generate(
189 | image_prompt_embeds=img_emb,
190 | num_samples=1,
191 | num_inference_steps=50,
192 | )
193 | output_images += images
194 | captions.append(f"prior_s {seed}, cfg {scale}, unet_s {seed_2}")
195 |
196 | all_images = input_images + output_images
197 | gen_images = vis_utils.create_table_plot(images=all_images, captions=captions)
198 | gen_images.save(save_path)
199 | self.prior.train()
200 |
201 | def get_dataloaders(self) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
202 | dataset_path = self.cfg.dataset_path
203 | if not isinstance(self.cfg.dataset_path, list):
204 | dataset_path = [self.cfg.dataset_path]
205 | datasets = []
206 | for path in dataset_path:
207 | datasets.append(
208 | PartsDataset(
209 | dataset_dir=path,
210 | image_processor=self.image_processor,
211 | use_ref=self.cfg.use_ref,
212 | max_crops=self.cfg.max_crops,
213 | sketch_prob=self.cfg.sketch_prob,
214 | )
215 | )
216 | dataset = torch.utils.data.ConcatDataset(datasets)
217 | print(f"Total number of samples: {len(dataset)}")
218 | dataset_weights = []
219 | for single_dataset in datasets:
220 | dataset_weights.extend([len(dataset) / len(single_dataset)] * len(single_dataset))
221 | sampler_train = torch.utils.data.WeightedRandomSampler(
222 | weights=dataset_weights, num_samples=len(dataset_weights)
223 | )
224 |
225 | validation_dataset = PartsDataset(
226 | dataset_dir=self.cfg.val_dataset_path,
227 | image_processor=self.image_processor,
228 | use_ref=self.cfg.use_ref,
229 | max_crops=self.cfg.max_crops,
230 | sketch_prob=self.cfg.sketch_prob,
231 | )
232 | train_dataloader = torch.utils.data.DataLoader(
233 | dataset,
234 | batch_size=self.cfg.train_batch_size,
235 | shuffle=sampler_train is None,
236 | num_workers=self.cfg.num_workers,
237 | sampler=sampler_train,
238 | )
239 |
240 | validation_dataloader = torch.utils.data.DataLoader(
241 | validation_dataset,
242 | batch_size=1,
243 | shuffle=True,
244 | num_workers=self.cfg.num_workers,
245 | )
246 | return train_dataloader, validation_dataloader
247 |
248 | def train(self):
249 | pbar = tqdm(range(self.train_step, self.cfg.max_train_steps + 1))
250 | # self.log_validation()
251 |
252 | while self.train_step < self.cfg.max_train_steps:
253 | train_loss = 0.0
254 | self.prior.train()
255 | lossbin = {i: 0 for i in range(10)}
256 | losscnt = {i: 1e-6 for i in range(10)}
257 |
258 | for sample_idx, batch in enumerate(self.train_dataloader):
259 | with self.accelerator.accumulate(self.prior):
260 | image, cond = batch
261 |
262 | image = image.to(self.weight_dtype).to(self.accelerator.device)
263 | if "crops" in cond:
264 | for crop_ind in range(len(cond["crops"])):
265 | cond["crops"][crop_ind] = (
266 | cond["crops"][crop_ind].to(self.weight_dtype).to(self.accelerator.device)
267 | )
268 | for key in cond.keys():
269 | if isinstance(cond[key], torch.Tensor):
270 | cond[key] = cond[key].to(self.accelerator.device)
271 |
272 | with torch.no_grad():
273 | image_embeds = self.ip_model.get_image_embeds(image, skip_uncond=True)
274 |
275 | b = image_embeds.size(0)
276 | nt = torch.randn((b,)).to(image_embeds.device)
277 | t = torch.sigmoid(nt)
278 | texp = t.view([b, *([1] * len(image_embeds.shape[1:]))])
279 | z_1 = torch.randn_like(image_embeds)
280 | noisy_latents = (1 - texp) * image_embeds + texp * z_1
281 |
282 | target = image_embeds
283 |
284 | # With probability 0.5, truncate the crop list (shared across the batch) to a random length so the model also learns to handle unpadded inputs
285 | if random.random() < 0.5:
286 | crops_to_keep = random.randint(1, len(cond["crops"]))
287 | cond["crops"] = cond["crops"][:crops_to_keep]
288 | cond_crops = cond["crops"]
289 |
290 | image_embed_inputs = []
291 | for crop_ind in range(len(cond_crops)):
292 | image_embed_inputs.append(
293 | self.ip_model.get_image_embeds(cond_crops[crop_ind], skip_uncond=True)
294 | )
295 | input_sequence = torch.cat(image_embed_inputs, dim=1)
296 |
297 | loss = 0
298 | image_feat_seq = input_sequence
299 |
300 | model_pred = self.prior(
301 | noisy_latents,
302 | t=t,
303 | cond=image_feat_seq,
304 | )
305 |
306 | batchwise_prior_loss = ((z_1 - target.float() - model_pred.float()) ** 2).mean(
307 | dim=list(range(1, len(target.shape)))
308 | )
309 | tlist = batchwise_prior_loss.detach().cpu().reshape(-1).tolist()
310 | ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
311 |
312 | # count based on t
313 | for t, l in ttloss:
314 | lossbin[int(t * 10)] += l
315 | losscnt[int(t * 10)] += 1
316 |
317 | loss += batchwise_prior_loss.mean()
318 | # Gather the losses across all processes for logging (if we use distributed training).
319 | avg_loss = self.accelerator.gather(loss.repeat(self.cfg.train_batch_size)).mean()
320 | train_loss += avg_loss.item() / self.cfg.gradient_accumulation_steps
321 |
322 | # Backprop
323 | self.accelerator.backward(loss)
324 | if self.accelerator.sync_gradients:
325 | self.accelerator.clip_grad_norm_(self.prior.parameters(), self.cfg.max_grad_norm)
326 | self.optimizer.step()
327 | self.optimizer.zero_grad()
328 |
329 | # Checks if the accelerator has performed an optimization step behind the scenes
330 | if self.accelerator.sync_gradients:
331 | pbar.update(1)
332 | self.train_step += 1
333 | train_loss = 0.0
334 |
335 | if self.accelerator.is_main_process:
336 |
337 | if self.train_step % self.cfg.checkpointing_steps == 1:
338 | if self.accelerator.is_main_process:
339 | save_path = self.cfg.output_dir # / f"learned_prior.pth"
340 | self.save_model(save_path)
341 | logger.info(f"Saved state to {save_path}")
342 | pbar.set_postfix(**{"loss": loss.cpu().detach().item()})
343 |
344 | if self.cfg.log_image_frequency > 0 and (self.train_step % self.cfg.log_image_frequency == 1):
345 | image_save_path = self.cfg.output_dir / "images" / f"{self.train_step}_step_images.jpg"
346 | image_save_path.parent.mkdir(exist_ok=True, parents=True)
347 | # Apply the full diffusion process
348 | conds_list = []
349 | for crop_ind in range(len(cond["crops"])):
350 | conds_list.append(cond["crops"][crop_ind][0])
351 |
352 | self.save_images(
353 | image=image[0],
354 | conds=conds_list,
355 | cond_sequence=image_feat_seq[:1],
356 | target_embeds=target[:1],
357 | save_path=image_save_path,
358 | )
359 |
360 | if self.cfg.log_validation > 0 and (self.train_step % self.cfg.log_validation == 0):
361 | # Run validation
362 | self.log_validation()
363 |
364 | if self.train_step >= self.cfg.max_train_steps:
365 | break
366 |
367 | self.train_dataloader, self.validation_dataloader = self.get_dataloaders()
368 | pbar.close()
369 |
370 | def log_validation(self):
371 | for sample_idx, batch in tqdm(enumerate(self.validation_dataloader)):
372 | image, cond = batch
373 | image = image.to(self.weight_dtype).to(self.accelerator.device)
374 | if "crops" in cond:
375 | for crop_ind in range(len(cond["crops"])):
376 | cond["crops"][crop_ind] = cond["crops"][crop_ind].to(self.weight_dtype).to(self.accelerator.device)
377 | for key in cond.keys():
378 | if isinstance(cond[key], torch.Tensor):
379 | cond[key] = cond[key].to(self.accelerator.device)
380 |
381 | with torch.no_grad():
382 | target_embeds = self.ip_model.get_image_embeds(image, skip_uncond=True)
383 | crops_to_keep = random.randint(1, len(cond["crops"]))
384 | cond["crops"] = cond["crops"][:crops_to_keep]
385 | cond_crops = cond["crops"]
386 | image_embed_inputs = []
387 | for crop_ind in range(len(cond_crops)):
388 | image_embed_inputs.append(self.ip_model.get_image_embeds(cond_crops[crop_ind], skip_uncond=True))
389 | input_sequence = torch.cat(image_embed_inputs, dim=1)
390 |
391 | image_save_path = self.cfg.output_dir / "val_images" / f"{self.train_step}_step_{sample_idx}_images.jpg"
392 | image_save_path.parent.mkdir(exist_ok=True, parents=True)
393 |
394 | save_target_image = image[0]
395 | conds_list = []
396 | for crop_ind in range(len(cond["crops"])):
397 | conds_list.append(cond["crops"][crop_ind][0])
398 |
399 | # Apply the full diffusion process
400 | self.save_images(
401 | image=save_target_image,
402 | conds=conds_list,
403 | cond_sequence=input_sequence[:1],
404 | target_embeds=target_embeds[:1],
405 | save_path=image_save_path,
406 | )
407 |
408 | if sample_idx == self.cfg.n_val_images:
409 | break
410 |
--------------------------------------------------------------------------------
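
The core of `Coach.train()` above is a rectified-flow (flow-matching) objective on IP-Adapter+ embeddings: a timestep is drawn as the sigmoid of a normal sample, the clean embedding is linearly mixed with Gaussian noise, and the DiT prior is trained to predict the velocity pointing from the data towards the noise. A self-contained sketch with random stand-in tensors (shapes are illustrative assumptions):

```python
import torch

b, num_tokens, embed_dim = 2, 16, 2048                 # illustrative batch and embedding shapes
image_embeds = torch.randn(b, num_tokens, embed_dim)   # clean target embeddings x

t = torch.sigmoid(torch.randn(b))                      # timestep per sample, t = sigmoid(N(0, 1))
texp = t.view(b, 1, 1)                                 # broadcast t over token and feature dims
z_1 = torch.randn_like(image_embeds)                   # pure-noise endpoint

noisy_latents = (1 - texp) * image_embeds + texp * z_1  # linear path x_t between data and noise

# The DiT prior is trained to predict the velocity z_1 - x; a random stand-in here.
model_pred = torch.randn_like(image_embeds)

# Per-sample squared error between the predicted and the true velocity
batchwise_loss = ((z_1 - image_embeds - model_pred) ** 2).mean(dim=(1, 2))
loss = batchwise_loss.mean()
print(batchwise_loss.shape, loss.item())
```
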
/training/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import traceback
3 | from pathlib import Path
4 |
5 | import einops
6 | import numpy as np
7 | import torchvision.transforms as T
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | from tqdm import tqdm
11 |
12 | from utils import bezier_utils
13 |
14 |
15 | class PartsDataset(Dataset):
16 | def __init__(
17 | self,
18 | dataset_dir: Path,
19 | clip_image_size: int = 224,
20 | image_processor=None,
21 | max_crops=3,
22 | use_ref: bool = True,
23 | ref_as_grid: bool = True,
24 | grid_size: int = 2,
25 | sketch_prob: float = 0.0,
26 | ):
27 | subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
28 |
29 | all_paths = []
30 | self.subdir_dict = {}
31 | for subdir in tqdm(subdirs):
32 | current_paths = list(subdir.glob("*.jpg"))
33 | current_target_paths = [p for p in current_paths if len(str(p.name).split("_")) == 2]
34 | if use_ref and len(current_target_paths) < 9:
35 | # Skip if not enough target images
36 | continue
37 | all_paths.extend(current_paths)
38 | self.subdir_dict[subdir] = current_target_paths
39 |
40 | print(f"Fraction of valid subdirs: {len(self.subdir_dict) / len(subdirs)}")
41 | self.target_paths = [p for p in all_paths if len(str(p.name).split("_")) == 2]
42 | source_paths = [p for p in all_paths if len(str(p.name).split("_")) == 3]
43 | self.source_target_mappings = {path: [] for path in self.target_paths}
44 | for source_path in source_paths:
45 | # Remove last part of the path
46 | target_path = Path("_".join(str(source_path).split("_")[:-1]) + ".jpg")
47 | if target_path in self.source_target_mappings:
48 | self.source_target_mappings[target_path].append(source_path)
49 | print(f"Loaded {len(self.target_paths)} target images")
50 |
51 | self.clip_image_size = clip_image_size
52 |
53 | self.image_processor = image_processor
54 |
55 | self.max_crops = max_crops
56 |
57 | self.use_ref = use_ref
58 |
59 | self.ref_as_grid = ref_as_grid
60 |
61 | self.grid_size = grid_size
62 |
63 | self.sketch_prob = sketch_prob
64 |
65 | def __len__(self):
66 | return len(self.target_paths)
67 |
68 | def paste_on_background(self, image, background, min_scale=0.4, max_scale=0.8):
69 | # Calculate aspect ratio and determine resizing based on the smaller dimension of the background
70 | aspect_ratio = image.width / image.height
71 | scale = random.uniform(min_scale, max_scale)
72 | new_width = int(min(background.width, background.height * aspect_ratio) * scale)
73 | new_height = int(new_width / aspect_ratio)
74 |
75 | # Resize image and calculate position
76 | image = image.resize((new_width, new_height), resample=Image.LANCZOS)
77 | pos_x = random.randint(0, background.width - new_width)
78 | pos_y = random.randint(0, background.height - new_height)
79 |
80 | # Paste the image using its alpha channel as mask if present
81 | background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
82 | return background
83 |
84 | def get_random_crop(self, image):
85 | crop_percent_x = random.uniform(0.8, 1.0)
86 | crop_percent_y = random.uniform(0.8, 1.0)
87 | # crop_percent_y = random.uniform(0.1, 0.7)
88 | crop_x = int(image.width * crop_percent_x)
89 | crop_y = int(image.height * crop_percent_y)
90 | x = random.randint(0, image.width - crop_x)
91 | y = random.randint(0, image.height - crop_y)
92 | return image.crop((x, y, x + crop_x, y + crop_y))
93 |
94 | def get_empty_image(self):
95 | empty_image = Image.new("RGB", (self.clip_image_size, self.clip_image_size), (255, 255, 255))
96 | return self.image_processor(empty_image)["pixel_values"][0]
97 |
98 | def __getitem__(self, i: int):
99 |
100 | out_dict = {}
101 |
102 | try:
103 | target_path = self.target_paths[i]
104 | image = Image.open(target_path).convert("RGB")
105 |
106 | input_parts = []
107 |
108 | source_paths = self.source_target_mappings[target_path]
109 | n_samples = random.randint(1, len(source_paths)) if source_paths else 0
110 |
111 | n_samples = min(n_samples, self.max_crops)
112 | source_paths = random.sample(source_paths, n_samples)
113 |
114 | if random.random() < 0.1:
115 | # Use empty image, but maybe still pass reference
116 | source_paths = []
117 |
118 | if self.use_ref:
119 | subdir = target_path.parent
120 | # Take something from same dir
121 | potential_refs = list(set(self.subdir_dict[subdir]) - {target_path})
122 | # Choose grid_size**2 references (4 for the default 2x2 grid)
123 | reference_paths = random.sample(potential_refs, self.grid_size**2)
124 | reference_images = [
125 | np.array(Image.open(reference_path).convert("RGB")) for reference_path in reference_paths
126 | ]
127 | # Concat all images as grid of 2x2
128 | reference_grid = np.stack(reference_images)
129 | grid_image = einops.rearrange(
130 | reference_grid,
131 | "(h w) h1 w1 c -> (h h1) (w w1) c",
132 | h=self.grid_size,
133 | )
134 | reference_image = Image.fromarray(grid_image).resize((512, 512))
135 |
136 | # Always add the reference image
137 | input_parts.append(reference_image)
138 |
139 | # Sample a subset
140 | for source_path in source_paths:
141 | source_image = Image.open(source_path).convert("RGB")
142 | if random.random() < 0.2:
143 | # Use a random crop of the source image instead of the full part
144 | source_image = self.get_random_crop(source_image)
145 | if random.random() < 0.2:
146 | source_image = T.v2.RandomRotation(degrees=30, expand=True, fill=255)(source_image)
147 | object_with_background = Image.new("RGB", image.size, (255, 255, 255))
148 | self.paste_on_background(source_image, object_with_background, min_scale=0.8, max_scale=0.95)
149 | if self.sketch_prob > 0 and random.random() < self.sketch_prob:
150 | num_lines = random.randint(8, 15)
151 | object_with_background = bezier_utils.get_sketch(
152 | object_with_background,
153 | total_curves=num_lines,
154 | drop_line_prob=0.1,
155 | )
156 | input_parts.append(object_with_background)
157 |
158 | # Pad with empty images up to max_crops (plus one slot when a reference grid is used)
159 | actual_max_crops = self.max_crops + 1 if self.use_ref else self.max_crops
160 | while len(input_parts) < actual_max_crops:
161 | input_parts.append(
162 | Image.new(
163 | "RGB",
164 | (self.clip_image_size, self.clip_image_size),
165 | (255, 255, 255),
166 | )
167 | )
168 |
169 | except Exception as e:
170 | print(f"Error processing image: {e}")
171 | traceback.print_exc()
172 | empty_image = Image.new("RGB", (self.clip_image_size, self.clip_image_size), (255, 255, 255))
173 | image = empty_image
174 | actual_max_crops = self.max_crops + 1 if self.use_ref else self.max_crops
175 | input_parts = [empty_image] * (actual_max_crops)
176 |
177 | clip_target_image = self.image_processor(image)["pixel_values"][0]
178 | clip_parts = [self.image_processor(part)["pixel_values"][0] for part in input_parts]
179 |
180 | out_dict["crops"] = clip_parts
181 |
182 | return clip_target_image, out_dict
183 |
--------------------------------------------------------------------------------
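
`PartsDataset` above expects a root directory of per-concept subdirectories in which target images have two underscore-separated name parts (e.g. `concept_0.jpg`) and part crops add a third (e.g. `concept_0_wing.jpg`); the filenames here are illustrative. A minimal loading sketch follows, where the dataset path is illustrative and a default `CLIPImageProcessor` stands in for the processor bundled with `IPAdapterPlusXL` that the training code actually uses:

```python
from pathlib import Path

from transformers import CLIPImageProcessor

from training.dataset import PartsDataset

# Assumption: a default CLIPImageProcessor in place of ip_model.clip_image_processor.
image_processor = CLIPImageProcessor()

dataset = PartsDataset(
    dataset_dir=Path("datasets/generated/generated_things"),  # illustrative path
    image_processor=image_processor,
    use_ref=False,
    max_crops=3,
    sketch_prob=0.0,
)

target, cond = dataset[0]
print(target.shape, len(cond["crops"]))  # processed target tensor and the padded crop list
```
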
/training/train_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 |
6 | @dataclass
7 | class TrainConfig:
8 | # Dataset path
9 | dataset_path: Union[Path, List[Path]] = Path("datasets/generated/generated_things")
10 | # Validation dataset path
11 | val_dataset_path: Path = Path("datasets/generated/generated_things_val")
12 | # The output directory where the model predictions and checkpoints will be written.
13 | output_dir: Path = Path("results/my_pit_model")
14 | # GPU device
15 | device: str = "cuda:0"
16 | # The resolution for input images, all the images will be resized to this size
17 | img_size: int = 1024
18 | # Batch size (per device) for the training dataloader
19 | train_batch_size: int = 1
20 | # Initial learning rate (after the potential warmup period) to use
21 | lr: float = 1e-5
22 | # Dataloader num workers.
23 | num_workers: int = 8
24 | # The beta1 parameter for the Adam optimizer.
25 | adam_beta1: float = 0.9
26 | # The beta2 parameter for the Adam optimizer
27 | adam_beta2: float = 0.999
28 | # Weight decay to use
29 | adam_weight_decay: float = 0.0 # 1e-2
30 | # Epsilon value for the Adam optimizer
31 | adam_epsilon: float = 1e-08
32 | # How often to save images; a non-positive value disables saving
33 | log_image_frequency: int = 500
34 | # How often to run validation
35 | log_validation: int = 5000
36 | # The number of images to save during each validation
37 | n_val_images: int = 10
38 | # A seed for reproducible training
39 | seed: Optional[int] = None
40 | # The number of accumulation steps to use
41 | gradient_accumulation_steps: int = 1
42 | # Whether to use mixed precision training
43 | mixed_precision: Optional[str] = "fp16"
44 | # Log to wandb
45 | report_to: Optional[str] = "wandb"
46 | # The number of training steps to run
47 | max_train_steps: int = 1000000
48 | # Max grad for clipping
49 | max_grad_norm: float = 1.0
50 | # How often to save checkpoints
51 | checkpointing_steps: int = 5000
52 | # The path to resume from
53 | resume_from_path: Optional[Path] = None
54 | # The step to resume from, mainly for logging
55 | resume_from_step: Optional[int] = None
56 | # DiT number of layers
57 | num_layers: int = 8
58 | # DiT hidden dimensionality
59 | hidden_dim: int = 2048
60 | # DiT number of attention heads
61 | num_attention_heads: int = 32
62 | # Whether to use a reference grid
63 | use_ref: bool = False
64 | # Max number of crops
65 | max_crops: int = 3
66 | # Probability of converting to sketch
67 | sketch_prob: float = 0.0
68 |
--------------------------------------------------------------------------------
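
`TrainConfig` above is a plain dataclass that pyrallis serializes to the `cfg.yaml` written by `Coach.__init__`. A small round-trip sketch, assuming pyrallis's `dump`/`load` helpers behave as documented (the field overrides are illustrative):

```python
from pathlib import Path

import pyrallis

from training.train_config import TrainConfig

cfg = TrainConfig(use_ref=True, sketch_prob=0.3, output_dir=Path("results/my_pit_model"))

cfg.output_dir.mkdir(parents=True, exist_ok=True)
with open(cfg.output_dir / "cfg.yaml", "w") as f:
    pyrallis.dump(cfg, f)  # mirrors how Coach archives the run configuration

with open(cfg.output_dir / "cfg.yaml") as f:
    restored = pyrallis.load(TrainConfig, f)

print(restored.use_ref, restored.sketch_prob)
```
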
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/PiT/73d26b0a4ab627e3f26c50763178bd75d0408cc2/utils/__init__.py
--------------------------------------------------------------------------------
/utils/bezier_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from PIL import Image
4 | import random
5 | from scipy.optimize import minimize
6 |
7 |
8 | def draw_contours(image, contours):
9 | """Draw contours on a blank image with random colors."""
10 | output = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
11 | for contour in contours:
12 | color = [random.randint(0, 255) for _ in range(3)]
13 | cv2.drawContours(output, [contour], -1, color, thickness=2)
14 | return output
15 |
16 |
17 | def contour_to_bezier(contour):
18 | """Convert an OpenCV contour to a cubic Bézier curve with 4 control points."""
19 | points = contour.reshape(-1, 2)
20 | P0 = points[0]
21 | P3 = points[-1]
22 |
23 | def bezier_point(t, control_points):
24 | P0, P1, P2, P3 = control_points
25 | return (1 - t) ** 3 * P0 + 3 * (1 - t) ** 2 * t * P1 + 3 * (1 - t) * t**2 * P2 + t**3 * P3
26 |
27 | def objective_function(params):
28 | P1 = np.array([params[0], params[1]])
29 | P2 = np.array([params[2], params[3]])
30 | control_points = [P0, P1, P2, P3]
31 | t_values = np.linspace(0, 1, len(points))
32 | bezier_points = np.array([bezier_point(t, control_points) for t in t_values])
33 | distances = np.sum((points - bezier_points) ** 2)
34 | return distances
35 |
36 | initial_P1 = P0 + (P3 - P0) / 3
37 | initial_P2 = P0 + 2 * (P3 - P0) / 3
38 | initial_guess = np.concatenate([initial_P1, initial_P2])
39 | result = minimize(objective_function, initial_guess, method="Nelder-Mead")
40 |
41 | P1 = np.array([result.x[0], result.x[1]])
42 | P2 = np.array([result.x[2], result.x[3]])
43 | return np.array([P0, P1, P2, P3])
44 |
45 |
46 | def visualize_result(img, control_points, num_points=100):
47 | """Draw a Bézier curve defined by control points."""
48 | t_values = np.linspace(0, 1, num_points)
49 | curve_points = []
50 | for t in t_values:
51 | x = (
52 | (1 - t) ** 3 * control_points[0][0]
53 | + 3 * (1 - t) ** 2 * t * control_points[1][0]
54 | + 3 * (1 - t) * t**2 * control_points[2][0]
55 | + t**3 * control_points[3][0]
56 | )
57 | y = (
58 | (1 - t) ** 3 * control_points[0][1]
59 | + 3 * (1 - t) ** 2 * t * control_points[1][1]
60 | + 3 * (1 - t) * t**2 * control_points[2][1]
61 | + t**3 * control_points[3][1]
62 | )
63 | curve_points.append([int(x), int(y)])
64 |
65 | curve_points = np.array(curve_points, dtype=np.int32)
66 | for i in range(len(curve_points) - 1):
67 | cv2.line(img, tuple(curve_points[i]), tuple(curve_points[i + 1]), (0, 0, 0), 2)
68 | return img
69 |
70 |
71 | def get_sketch(image, total_curves=10, drop_line_prob=0.0, pad=False):
72 | """
73 | Convert an image to a sketch made of Bézier curves.
74 | Args:
75 | image: Input PIL image (assumed to be drawn on a white background)
76 | total_curves: Total number of Bézier curves to fit (default: 10)
77 | drop_line_prob: Probability of dropping each fitted curve from the output (default: 0.0)
78 | pad: Whether to pad the image to a square before processing (default: False)
79 | Returns:
80 | PIL.Image.Image: Image containing the Bézier curve sketch
81 | """
82 | # Load and preprocess image
83 | image = np.array(image)
84 |
85 | # Pad image to square
86 | height, width, _ = image.shape
87 | if pad:
88 | max_side = max(height, width) + 20
89 | pad_h = (max_side - height) // 2
90 | pad_w = (max_side - width) // 2
91 |
92 | image = np.pad(
93 | image,
94 | ((pad_h, max_side - height - pad_h), (pad_w, max_side - width - pad_w), (0, 0)),
95 | mode="constant",
96 | constant_values=255,
97 | )
98 |
99 | # Convert to binary
100 | gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
101 | binary = np.where(gray_image < 255, 255, 0).astype(np.uint8)
102 |
103 | # Clean up binary image
104 | kernel = np.ones((5, 5), np.uint8)
105 | binary = cv2.erode(binary, kernel, iterations=4)
106 |
107 | # Get contours
108 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
109 |
110 | # Calculate curve allocation
111 | contour_lengths = [cv2.arcLength(contour, closed=True) for contour in contours]
112 | total_length = sum(contour_lengths)
113 | curve_allocation = [round((length / total_length) * total_curves) for length in contour_lengths]
114 |
115 | # Adjust allocation to match total_curves
116 | curve_allocation = np.clip(curve_allocation, 1, total_curves)
117 | while sum(curve_allocation) > total_curves:
118 | curve_allocation[np.argmax(curve_allocation)] -= 1
119 | while sum(curve_allocation) < total_curves:
120 | curve_allocation[np.argmin(curve_allocation)] += 1
121 |
122 | # Fit Bézier curves
123 | fitted_curves = []
124 | for contour, n_curves in zip(contours, curve_allocation):
125 | segment_length = len(contour) // n_curves
126 | if segment_length == 0:
127 | continue
128 | segments = [contour[i : i + segment_length] for i in range(0, len(contour), segment_length)]
129 |
130 | for segment in segments:
131 | control_points = contour_to_bezier(segment)
132 | fitted_curves.append(control_points)
133 |
134 | # Create final image
135 | curves_image = np.ones_like(image, dtype=np.uint8) * 255
136 | for curve in fitted_curves:
137 | if random.random() < drop_line_prob:
138 | continue
139 | curves_image = visualize_result(curves_image, curve)
140 | curves_image = Image.fromarray(curves_image)
141 | return curves_image
142 |
--------------------------------------------------------------------------------
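
The `get_sketch` utility above converts a part image into a sparse line drawing by fitting cubic Bézier curves to its external contours. A minimal usage sketch; the asset path is taken from this repository's `assets/` folder, and because the binarization treats any non-white pixel as foreground, a clean pure-white background matters:

```python
from PIL import Image

from utils import bezier_utils

# Works best on RGB images with a pure-white background, as produced by paste_on_background.
part = Image.open("assets/product_parts/part_a.jpg").convert("RGB")

sketch = bezier_utils.get_sketch(part, total_curves=12, drop_line_prob=0.1, pad=True)
sketch.save("part_a_sketch.jpg")
```
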
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from typing import List, Tuple, Optional
3 |
4 | import numpy as np
5 | from PIL import Image, ImageDraw, ImageFont
6 |
7 | LINE_WIDTH = 20
8 |
9 |
10 | def add_text_to_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0),
11 | min_lines: Optional[int] = None, add_below: bool = True):
12 | # Wrap the text into lines of at most LINE_WIDTH characters
13 | lines = textwrap.wrap(text, width=LINE_WIDTH)
14 | if min_lines is not None and len(lines) < min_lines:
15 | if add_below:
16 | lines += [''] * (min_lines - len(lines))
17 | else:
18 | lines = [''] * (min_lines - len(lines)) + lines
19 | h, w, c = image.shape
20 | offset = int(h * .12)
21 | img = np.ones((h + offset * len(lines), w, c), dtype=np.uint8) * 255
22 | font_size = int(offset * .8)
23 |
24 | try:
25 | font = ImageFont.truetype("assets/OpenSans-Regular.ttf", font_size)
26 | textsize = font.getbbox(text)
27 | y_offset = (offset - textsize[3]) // 2
28 | except Exception:
29 | font = ImageFont.load_default()
30 | y_offset = offset // 2
31 |
32 | if add_below:
33 | img[:h] = image
34 | else:
35 | img[-h:] = image
36 | img = Image.fromarray(img)
37 | draw = ImageDraw.Draw(img)
38 | for i, line in enumerate(lines):
39 | line_size = font.getbbox(line)
40 | text_x = (w - line_size[2]) // 2
41 | if add_below:
42 | draw.text((text_x, h + y_offset + offset * i), line, font=font, fill=text_color)
43 | else:
44 | draw.text((text_x, 0 + y_offset + offset * i), line, font=font, fill=text_color)
45 | return np.array(img)
46 |
47 |
48 | def create_table_plot(images: List[Image.Image], titles: List[str]=None, captions: List[str]=None) -> Image.Image:
49 | title_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in titles]) if titles is not None else 0
50 | caption_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in captions]) if captions is not None else 0
51 | out_images = []
52 | for i in range(len(images)):
53 | im = np.array(images[i])
54 | if titles is not None:
55 | im = add_text_to_image(im, titles[i], add_below=False, min_lines=title_max_lines)
56 | if captions is not None:
57 | im = add_text_to_image(im, captions[i], add_below=True, min_lines=caption_max_lines)
58 | out_images.append(im)
59 | image = Image.fromarray(np.concatenate(out_images, axis=1))
60 | return image
61 |
--------------------------------------------------------------------------------
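
Finally, `create_table_plot` above is what the training and inference code uses to assemble the labeled image strips saved to disk. A small synthetic example with placeholder tiles; captions are rendered with `assets/OpenSans-Regular.ttf` when present and fall back to PIL's default font otherwise:

```python
from PIL import Image

from utils import vis_utils

# Three placeholder tiles standing in for condition crops and a generated result
images = [Image.new("RGB", (256, 256), color) for color in ("white", "lightgray", "gray")]

table = vis_utils.create_table_plot(
    images=images,
    titles=["Condition", "Condition", "Output"],
    captions=["part_a", "part_b", "prior_s 42, cfg 4"],
)
table.save("table_example.jpg")
```
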