├── README.md
├── UI_GUIDE.md
├── assets
│   ├── icon.svg
│   ├── pipe.png
│   ├── polygons.gif
│   ├── region_points.gif
│   ├── regions.gif
│   ├── time.png
│   └── ui_overview.png
├── region_utils
│   ├── __init__.py
│   ├── cycle_sde.py
│   ├── drag.py
│   ├── evaluator.py
│   └── ui_utils.py
├── requirements.txt
├── run_eval.py
└── ui.py
/README.md:
--------------------------------------------------------------------------------
1 | # RegionDrag: Fast Region-Based Image Editing with Diffusion Models (ECCV 2024)
2 | **Jingyi Lu¹, [Xinghui Li](https://xinghui-li.github.io/)², [Kai Han](https://www.kaihan.org/)¹**
3 | [¹Visual AI Lab, The University of Hong Kong](https://visailab.github.io/index.html); [²Active Vision Lab, University of Oxford](https://www.robots.ox.ac.uk/ActiveVision/)
4 |
5 | [Open in Colab](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing)
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | RegionDrag proposes using pairs of **regions**, instead of the **points** used by DragGAN and DragDiffusion, to drag image contents. Visit our [project page](https://visual-ai.github.io/regiondrag) for various region input examples.
14 | - **A region is equivalent to a large number of points**, providing richer input context and improving image consistency.
15 | - By using regions as input, we reduce the run time to **1.5 seconds** (close to the time of a 20-step SD image generation).
16 |
17 | During inference, the SD latent representations of the input image are extracted from 🔴 **RED** regions during inversion and mapped onto 🔵 **BLUE** regions during denoising, across multiple timesteps.
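
At its core, this mapping is a copy-and-paste of latent features between two point sets (see `copy_and_paste` in `region_utils/drag.py`). A minimal sketch, assuming `source` and `target` are `(N, 2)` tensors of (x, y) coordinates already downscaled to the latent resolution:

```
import torch

def copy_and_paste(source_latents, target_latents, source, target):
    # Move latent features from handle (RED) positions to target (BLUE) positions.
    target_latents[0, :, target[:, 1], target[:, 0]] = source_latents[0, :, source[:, 1], source[:, 0]]
    return target_latents

# Toy usage: shift a small patch of a (1, C, H, W) latent one position to the right.
latent = torch.randn(1, 4, 64, 64)
source = torch.tensor([[10, 10], [11, 10], [10, 11], [11, 11]])  # handle (RED) points
target = source + torch.tensor([1, 0])                           # target (BLUE) points
latent = copy_and_paste(latent.clone(), latent, source, target)
```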
18 |
19 | ## Installation
20 | CUDA support is required to run our code; you can try our [Colab Demo](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing) for easy access to GPU resources.
21 | To install RegionDrag locally, run the following in a terminal:
22 | ```
23 | git clone https://github.com/LuJingyi-John/RegionDrag.git
24 | cd RegionDrag
25 | pip install -r requirements.txt
26 |
27 | # Now supports the InstantDrag backbone (https://github.com/SNU-VGILab/InstantDrag)
28 | git clone https://github.com/SNU-VGILab/InstantDrag instantdrag
29 | ```
30 |
31 | ## Run RegionDrag
32 | After installing the requirements, you can launch the user interface with:
33 | ```
34 | python3 ui.py
35 | ```
36 | For detailed instructions on using our UI, check out our [User Guide](./UI_GUIDE.md).
37 |
38 | ## DragBench-SR & DragBench-DR
39 | To evaluate region-based editing, we introduce [DragBench-SR](https://github.com/ML-GSAI/SDE-Drag) and [DragBench-DR](https://github.com/Yujun-Shi/DragDiffusion/) (R is short for 'Region'), which are modified versions of DragBench-S (100 samples) and DragBench-D (205 samples). These benchmarks are consistent with their point-based counterparts but use regions instead of points to reflect user intention. You can download the dataset [HERE](https://drive.google.com/file/d/1rdi4Rqka8zqHTbPyhQYtFC2UdWvAeAGV/view?usp=sharing).
40 |
41 |
42 | ```
43 | drag_data/
44 | ├── dragbench-dr/
45 | │   ├── animals/
46 | │   │   ├── JH_2023-09-14-1820-16/
47 | │   │   │   ├── original_image.png
48 | │   │   │   ├── user_drag.png
49 | │   │   │   ├── meta_data.pkl
50 | │   │   │   └── meta_data_region.pkl
51 | │   │   └── ...
52 | │   └── ...
53 | └── dragbench-sr/
54 |     ├── art_0/
55 |     │   ├── original_image.png
56 |     │   ├── user_drag.png
57 |     │   ├── meta_data.pkl
58 |     │   └── meta_data_region.pkl
59 |     └── ...
60 | ```
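
Each benchmark sample can also be edited programmatically with the helpers in `region_utils/drag.py`. A minimal sketch, where the sample path and the `steps`/`start_t`/`end_t`/`noise_scale` values are illustrative rather than tuned settings:

```
from region_utils.drag import drag, get_drag_data

# Load one sample: original image, prompt, editing mask, and paired handle/target points.
drag_data = get_drag_data('drag_data/dragbench-sr/art_0')

# Run region-based dragging (requires CUDA); the hyperparameter values here are illustrative.
edited_image = drag(drag_data, steps=20, start_t=0.75, end_t=0.25, noise_scale=0.0,
                    seed=42, method='Encode then CP', save_path='output/art_0')
```
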
61 | `meta_data.pkl` and `meta_data_region.pkl` store the user interaction metadata as a dictionary:
62 |
63 | ```
64 | {
65 |     'prompt': text prompt describing the output image,
66 |     'points': list of points [(x1, y1), (x2, y2), ..., (xn, yn)],
67 |               where (x1, y1), (x3, y3), ... are handle points and (x2, y2), (x4, y4), ... are target points,
68 |     'mask': a binary mask specifying the editing area,
69 | }
70 | ```
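
A minimal sketch of loading this metadata and splitting the interleaved point list into handle and target points (mirroring `get_meta_data` in `region_utils/drag.py`); the sample path is illustrative:

```
import pickle

meta_path = 'drag_data/dragbench-sr/art_0/meta_data_region.pkl'  # illustrative sample path
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

prompt, mask = meta['prompt'], meta['mask']
handle_points = meta['points'][0::2]  # (x1, y1), (x3, y3), ...
target_points = meta['points'][1::2]  # (x2, y2), (x4, y4), ...
```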
71 |
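The metrics implemented in `region_utils/evaluator.py` (mean distance and LPIPS) can then be computed on an edited result. A rough sketch, assuming `original` and `dragged` are `H x W x 3` uint8 arrays and `prompt`, `handle_points`, `target_points` come from the metadata above; the edited-image path is illustrative:

```
import numpy as np
from PIL import Image
from region_utils.evaluator import DragEvaluator

evaluator = DragEvaluator()
original = np.array(Image.open('drag_data/dragbench-sr/art_0/original_image.png'))
dragged = np.array(Image.open('output/art_0/dragged_image.png'))  # illustrative output path

# Mean distance between the intended targets and where the handle content actually moved (SD features).
md = evaluator.compute_distance(original, dragged, handle_points, target_points, method='sd', prompt=prompt)
# Perceptual similarity between the original and the edited image.
lpips_score = evaluator.compute_lpips(original, dragged)
print(f'Mean Distance: {md:.4f}, LPIPS: {lpips_score:.4f}')
```
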
72 | ## BibTeX
73 | ```
74 | @inproceedings{lu2024regiondrag,
75 | author = {Jingyi Lu and Xinghui Li and Kai Han},
76 | title = {RegionDrag: Fast Region-Based Image Editing with Diffusion Models},
77 | booktitle = {European Conference on Computer Vision (ECCV)},
78 | year = {2024},
79 | }
80 | ```
81 |
82 | ## Related links
83 | * [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
84 | * [DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing](https://github.com/Yujun-Shi/DragDiffusion/)
85 | * [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
86 | * [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
87 | * [The Blessing of Randomness: SDE Beats ODE in General Diffusion-based Image Editing](https://github.com/ML-GSAI/SDE-Drag)
88 | * [InstantDrag: Improving Interactivity in Drag-based Image Editing](https://github.com/SNU-VGILab/InstantDrag)
89 |
90 | ## Acknowledgement
91 | Insightful discussions with Cheng Silin and Huang Xiaohu were instrumental in refining our methodology. The intuitive layout of the [DragDiffusion](https://github.com/Yujun-Shi/DragDiffusion/) project inspired our user interface design. Our SDE scheduler implementation builds upon the groundbreaking work by Shen Nie et al. in their [SDE-Drag](https://github.com/ML-GSAI/SDE-Drag) project.
92 |
93 |
94 |
--------------------------------------------------------------------------------
/UI_GUIDE.md:
--------------------------------------------------------------------------------
1 | # RegionDrag: Fast Region-Based Image Editing with Diffusion Models (ECCV 2024)
2 | **Jingyi Lu†, [Xinghui Li‡](https://xinghui-li.github.io/), [Kai Han†](https://www.kaihan.org/)**
3 | [Visual AI Lab, The University of Hong Kong†](https://visailab.github.io/index.html); [Active Vision Lab, University of Oxford‡](https://www.robots.ox.ac.uk/ActiveVision/)
4 |
5 | [Open in Colab](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing)
6 |
7 |
8 |
9 |
10 |
11 | ## Overview
12 | RegionDrag supports a variety of inputs. You can input regions or points to drag image contents from 🔴 **RED** to 🔵 **BLUE**. Below is an overview of the different components of our UI. For detailed instructions on installing RegionDrag, check out our [README](./README.md).
13 |
14 |
15 |
16 | ## Tips
17 | - Increasing the `Handle Noise Scale` can remove handle content. If that does not work, you can drag 🔴 some other content to cover 🔵 the content you would like to remove.
18 | - The image displayed in the `Results` column is a preview computed directly from your inputs, before you click `Run Drag`. A better preview generally implies a better editing result.
19 | - If you find the preview image satisfactory, you can try changing the `Method` from `Encode then CP` to `CP then Encode`.
20 |
21 | ## Input pairs of regions
22 | - **Step 1:** Upload an image on the left, then click `Fit Canvas` to adjust the image size
23 | - **Step 2:** Add regions (draw a mask on the left, then click `Add Region`)
24 | - **Step 3:** Click `Run Drag`
25 |
26 |
27 |
28 | ## Input pairs of polygons
29 | - **Step 1:** Upload an image on the left, then click `Fit Canvas` to adjust the image size
30 | - **Step 2:** Click points on the middle image (you can choose to input triangles or quadrilaterals)
31 | - **Step 3:** Click `Run Drag`
32 |
33 |
34 |
35 | ## Input regions and manipulate them by points
36 | - **Step 1:** Upload an image on the left, then click `Fit Canvas` to adjust the image size
37 | - **Step 2:** Draw masks on the left to represent regions
38 | - **Step 3:** Click points on the middle image to control these regions
39 | - **Step 4:** Click `Run Drag`
40 |
41 |
--------------------------------------------------------------------------------
/assets/icon.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/icon.svg
--------------------------------------------------------------------------------
/assets/pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/pipe.png
--------------------------------------------------------------------------------
/assets/polygons.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/polygons.gif
--------------------------------------------------------------------------------
/assets/region_points.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/region_points.gif
--------------------------------------------------------------------------------
/assets/regions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/regions.gif
--------------------------------------------------------------------------------
/assets/time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/time.png
--------------------------------------------------------------------------------
/assets/ui_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/ui_overview.png
--------------------------------------------------------------------------------
/region_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/region_utils/__init__.py
--------------------------------------------------------------------------------
/region_utils/cycle_sde.py:
--------------------------------------------------------------------------------
1 | # This file is developed based on the work of [ML-GSAI/SDE-Drag],
2 | # which can be found at https://github.com/ML-GSAI/SDE-Drag
3 |
4 | import torch
5 | import random
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision import transforms
9 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler
10 |
11 | def load_model(version="v1-5", torch_device='cuda', torch_dtype=torch.float16, verbose=True):
12 | pipe_paths = {
13 | 'v1-5' : "runwayml/stable-diffusion-v1-5",
14 | 'v2-1' : "stabilityai/stable-diffusion-2-1",
15 | 'xl' : "stabilityai/stable-diffusion-xl-base-1.0"
16 | }
17 | pipe_path = pipe_paths.get(version, pipe_paths['v1-5'])
18 | pipe = StableDiffusionPipeline if version in ['v1-5', 'v2-1'] else StableDiffusionXLPipeline
19 |
20 | if verbose:
21 | print(f'Loading model from {pipe_path}.')
22 | pipe = pipe.from_pretrained(pipe_path, torch_dtype=torch_dtype).to(torch_device)
23 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
24 |
25 |     # IP-Adapter
26 | if version in ['v1-5', 'v2-1']:
27 | subfolder, weight_name, ip_adapter_scale = "models", "ip-adapter-plus_sd15.bin", 0.5
28 | else:
29 | subfolder, weight_name, ip_adapter_scale = "sdxl_models", "ip-adapter_sdxl.bin", 0.6
30 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder=subfolder, weight_name=weight_name)
31 | pipe.set_ip_adapter_scale(ip_adapter_scale)
32 |
33 | tokenizer_2 = pipe.tokenizer_2 if version == 'xl' else None
34 | text_encoder_2 = pipe.text_encoder_2 if version == 'xl' else None
35 | return pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, pipe.scheduler, pipe.feature_extractor, pipe.image_encoder, tokenizer_2, text_encoder_2
36 |
37 | @torch.no_grad()
38 | def get_text_embed(prompt: list, tokenizer, text_encoder, tokenizer_2=None, text_encoder_2=None, torch_device='cuda'):
39 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
40 | prompt_embeds = text_encoder(text_input.input_ids.to(torch_device), output_hidden_states=True)
41 | pooled_prompt_embeds, prompt_embeds = prompt_embeds[0], prompt_embeds.hidden_states[-2]
42 |
43 | if tokenizer_2 is not None and text_encoder_2 is not None:
44 | text_input = tokenizer_2(prompt, padding="max_length", max_length=tokenizer_2.model_max_length, truncation=True, return_tensors="pt")
45 | prompt_embeds_2 = text_encoder_2(text_input.input_ids.to(torch_device), output_hidden_states=True)
46 | pooled_prompt_embeds, prompt_embeds_2 = prompt_embeds_2[0], prompt_embeds_2.hidden_states[-2]
47 | prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_2], dim=-1)
48 |
49 | return pooled_prompt_embeds, prompt_embeds
50 |
51 | @torch.no_grad()
52 | def get_img_latent(image, vae, torch_device='cuda', dtype=torch.float16, size=None):
53 | # upcast vae for sdxl, attention blocks can be in torch.float16
54 | upcast_dtype = torch.float32 if 'xl-base-1.0' in vae.config._name_or_path and dtype == torch.float16 else dtype
55 | if dtype == torch.float16:
56 | vae = vae.to(upcast_dtype)
57 | for module in [vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block]:
58 | module = module.to(dtype)
59 |
60 | image = Image.open(image).convert('RGB') if isinstance(image, str) else Image.fromarray(image)
61 | image = image.resize(size) if size else image
62 | image = transforms.ToTensor()(image).unsqueeze(0).to(torch_device, upcast_dtype)
63 | latents = vae.encode(image * 2 - 1).latent_dist.sample() * 0.18215
64 | return latents.to(dtype)
65 |
66 | def set_seed(seed):
67 | torch.manual_seed(seed)
68 | random.seed(seed)
69 | np.random.seed(seed)
70 |
71 | torch.backends.cudnn.deterministic = True
72 | torch.backends.cudnn.benchmark = False
73 |
74 | class Sampler():
75 | def __init__(self, unet, scheduler, num_steps=100):
76 | scheduler.set_timesteps(num_steps)
77 | self.num_inference_steps = num_steps
78 | self.num_train_timesteps = len(scheduler)
79 |
80 | self.alphas = scheduler.alphas
81 | self.alphas_cumprod = scheduler.alphas_cumprod
82 |
83 | self.final_alpha_cumprod = torch.tensor(1.0)
84 | self.initial_alpha_cumprod = torch.tensor(1.0)
85 |
86 | self.unet = unet
87 |
88 | def get_eps(self, img, timestep, guidance_scale, text_embeddings, lora_scale=None, added_cond_kwargs=None):
89 | guidance_scale = max(1, guidance_scale)
90 |
91 | text_embeddings = text_embeddings if guidance_scale > 1. else text_embeddings[-1:]
92 | latent_model_input = torch.cat([img] * 2) if guidance_scale > 1. else img
93 | cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}
94 |
95 | if guidance_scale == 1. and added_cond_kwargs is not None:
96 | added_cond_kwargs = {k: v[-1:] for k, v in added_cond_kwargs.items()}
97 |
98 | noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample
99 |
100 | if guidance_scale > 1.:
101 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
102 | elif guidance_scale == 1.:
103 | noise_pred_text = noise_pred
104 | noise_pred_uncond = 0.
105 | else:
106 | raise NotImplementedError(guidance_scale)
107 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
108 |
109 | return noise_pred
110 |
111 | def sample(self, timestep, sample, guidance_scale, text_embeddings, sde=False, noise=None, eta=1., lora_scale=None, added_cond_kwargs=None):
112 | eps = self.get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale, added_cond_kwargs)
113 |
114 | prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
115 |
116 | alpha_prod_t = self.alphas_cumprod[timestep]
117 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
118 |
119 | beta_prod_t = 1 - alpha_prod_t
120 |
121 | sigma_t = eta * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5) * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5) if sde else 0
122 |
123 | pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
124 | pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t ** 2) ** (0.5)
125 |
126 | noise = torch.randn_like(sample, device=sample.device) if noise is None else noise
127 | img = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
128 |
129 | return img
130 |
131 | def forward_sde(self, timestep, sample, guidance_scale, text_embeddings, eta=1., lora_scale=None, added_cond_kwargs=None):
132 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
133 |
134 | alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
135 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
136 |
137 | beta_prod_t_prev = 1 - alpha_prod_t_prev
138 |
139 | x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (0.5) * torch.randn_like(sample, device=sample.device)
140 | eps = self.get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale, added_cond_kwargs)
141 |
142 | sigma_t_prev = eta * ((1 - alpha_prod_t) / (1 - alpha_prod_t_prev)) ** (0.5) * (1 - alpha_prod_t_prev / alpha_prod_t) ** (0.5)
143 |
144 | pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
145 | pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev ** 2) ** (0.5)
146 |
147 | noise = (sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps) / sigma_t_prev
148 |
149 | return x_prev, noise
150 |
151 | def forward_ode(self, timestep, sample, guidance_scale, text_embeddings, lora_scale=None, added_cond_kwargs=None):
152 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
153 |
154 | alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
155 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
156 |
157 | beta_prod_t = 1 - alpha_prod_t
158 |
159 | eps = self.get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale, added_cond_kwargs)
160 | pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
161 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * eps
162 |
163 | img = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
164 |
165 | noise = None
166 | return img, noise
--------------------------------------------------------------------------------
/region_utils/drag.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import numpy as np
4 | import pickle
5 | from PIL import Image
6 | import cv2
7 | from tqdm import tqdm
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .cycle_sde import Sampler, get_img_latent, get_text_embed, load_model, set_seed
13 |
14 | # select version from v1-5 (recommended), v2-1, xl
15 | sd_version = 'v1-5'
16 |
17 | # --- To include: InstantDrag (https://github.com/SNU-VGILab/InstantDrag) --- #
18 | import sys
19 | sys.path.append('instantdrag/')
20 | if not os.path.exists('instantdrag/utils/__init__.py'):
21 | open('instantdrag/utils/__init__.py', 'a').close()
22 |
23 | from huggingface_hub import snapshot_download
24 | from instantdrag.demo.demo_utils import InstantDragPipeline
25 | os.makedirs("./checkpoints", exist_ok=True)
26 | snapshot_download("alex4727/InstantDrag", local_dir="./checkpoints")
27 |
28 | def scale_schedule(begin, end, n, length, type='linear'):
29 | if type == 'constant':
30 | return end
31 | elif type == 'linear':
32 | return begin + (end - begin) * n / length
33 | elif type == 'cos':
34 | factor = (1 - math.cos(n * math.pi / length)) / 2
35 | return (1 - factor) * begin + factor * end
36 | else:
37 | raise NotImplementedError(type)
38 |
39 | def get_meta_data(meta_data_path):
40 | with open(meta_data_path, 'rb') as file:
41 | meta_data = pickle.load(file)
42 | prompt = meta_data['prompt']
43 | mask = meta_data['mask']
44 | points = meta_data['points']
45 | source = points[0:-1:2]
46 | target = points[1::2]
47 | return prompt, mask, source, target
48 |
49 | def get_drag_data(data_path):
50 | ori_image_path = os.path.join(data_path, 'original_image.png')
51 | meta_data_path = os.path.join(data_path, 'meta_data_region.pkl')
52 |
53 | original_image = np.array(Image.open(ori_image_path))
54 | prompt, mask, source, target = get_meta_data(meta_data_path)
55 |
56 | return {
57 | 'ori_image' : original_image, 'preview' : original_image, 'prompt' : prompt,
58 | 'mask' : mask, 'source' : np.array(source), 'target' : np.array(target)
59 | }
60 |
61 | def reverse_and_repeat_every_n_elements(lst, n, repeat=1):
62 | """
63 | Reverse every n elements in a given list, then repeat the reversed segments
64 | the specified number of times.
65 | Example:
66 | >>> reverse_and_repeat_every_n_elements([1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 2)
67 | [3, 2, 1, 3, 2, 1, 6, 5, 4, 6, 5, 4, 9, 8, 7, 9, 8, 7]
68 | """
69 | if not lst or n < 1:
70 | return lst
71 | return [element for i in range(0, len(lst), n) for _ in range(repeat) for element in reversed(lst[i:i+n])]
72 |
73 | def get_border_points(points):
74 | x_max, y_max = np.amax(points, axis=0)
75 | mask = np.zeros((y_max+1, x_max+1), np.uint8)
76 | mask[points[:, 1], points[:, 0]] = 1
77 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
78 | border_points = np.concatenate([contour[:, 0, :] for contour in contours], axis=0)
79 | return border_points
80 |
81 | def postprocess(vae, latent, ori_image, mask):
82 | dtype = latent.dtype
83 | upcast_dtype = torch.float32 if 'xl-base-1.0' in vae.config._name_or_path and dtype == torch.float16 else dtype
84 | H, W = ori_image.shape[:2]
85 |
86 | if dtype == torch.float16:
87 | vae = vae.to(upcast_dtype)
88 | for module in [vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block]:
89 | module = module.to(dtype)
90 |
91 | image = vae.decode(latent / 0.18215).sample / 2 + 0.5
92 | image = (image.clamp(0, 1).permute(0, 2, 3, 1)[0].cpu().numpy() * 255).astype(np.uint8)
93 | image = cv2.resize(image, (W, H))
94 |
95 | if not np.all(mask == 1):
96 | image = np.where(mask[:, :, None], image, ori_image)
97 |
98 | return image
99 |
100 | def copy_and_paste(source_latents, target_latents, source, target):
101 | target_latents[0, :, target[:, 1], target[:, 0]] = source_latents[0, :, source[:, 1], source[:, 0]]
102 | return target_latents
103 |
104 | def blur_source(latents, noise_scale, source):
105 | img_scale = (1 - noise_scale ** 2) ** (0.5) if noise_scale < 1 else 0
106 | latents[0, :, source[:, 1], source[:, 0]] = latents[0, :, source[:, 1], source[:, 0]] * img_scale + \
107 | torch.randn_like(latents[0, :, source[:, 1], source[:, 0]]) * noise_scale
108 | return latents
109 |
110 | def ip_encode_image(feature_extractor, image_encoder, image):
111 | dtype = next(image_encoder.parameters()).dtype
112 | device = next(image_encoder.parameters()).device
113 |
114 | image = feature_extractor(image, return_tensors="pt").pixel_values.to(device=device, dtype=dtype)
115 | image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
116 | uncond_image_enc_hidden_states = image_encoder(
117 | torch.zeros_like(image), output_hidden_states=True
118 | ).hidden_states[-2]
119 | image_embeds = torch.stack([uncond_image_enc_hidden_states, image_enc_hidden_states], dim=0)
120 |
121 | return [image_embeds]
122 |
123 | def forward(scheduler, sampler, steps, start_t, latent, text_embeddings, added_cond_kwargs, progress=tqdm, sde=True):
124 | forward_func = sampler.forward_sde if sde else sampler.forward_ode
125 | hook_latents = [latent,]; noises = []; cfg_scales = []
126 | start_t = int(start_t * steps)
127 |
128 | for index, t in enumerate(progress(scheduler.timesteps[(steps - start_t):].flip(dims=[0])), start=1):
129 | cfg_scale = scale_schedule(1, 1, index, steps, type='linear')
130 | latent, noise = forward_func(t, latent, cfg_scale, text_embeddings, added_cond_kwargs=added_cond_kwargs)
131 | hook_latents.append(latent); noises.append(noise); cfg_scales.append(cfg_scale)
132 |
133 | return hook_latents, noises, cfg_scales
134 |
135 | def backward(scheduler, sampler, steps, start_t, end_t, noise_scale, hook_latents, noises, cfg_scales, mask, text_embeddings, added_cond_kwargs, blur, source, target, progress=tqdm, latent=None, sde=True):
136 | start_t = int(start_t * steps)
137 | end_t = int(end_t * steps)
138 |
139 | latent = hook_latents[-1].clone() if latent is None else latent
140 | latent = blur_source(latent, noise_scale, blur)
141 |
142 | for t in progress(scheduler.timesteps[(steps-start_t- 1):-1]):
143 | hook_latent = hook_latents.pop()
144 | latent = copy_and_paste(hook_latent, latent, source, target) if t >= end_t else latent
145 | latent = torch.where(mask == 1, latent, hook_latent)
146 | latent = sampler.sample(t, latent, cfg_scales.pop(), text_embeddings, sde=sde, noise=noises.pop(), added_cond_kwargs=added_cond_kwargs)
147 | return latent
148 |
149 | def drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress=tqdm, method='Encode then CP', save_path=''):
150 | set_seed(seed)
151 | device = 'cuda'
152 | ori_image, preview, prompt, mask, source, target = drag_data.values()
153 |
154 | if method in ('Encode then CP', 'CP then Encode'):
155 | global vae, tokenizer, text_encoder, unet, scheduler, feature_extractor, image_encoder, tokenizer_2, text_encoder_2
156 | if 'vae' not in globals():
157 | vae, tokenizer, text_encoder, unet, scheduler, feature_extractor, image_encoder, tokenizer_2, text_encoder_2 = load_model(sd_version)
158 |
159 | def copy_key_hook(module, input, output):
160 | keys.append(output)
161 | def copy_value_hook(module, input, output):
162 | values.append(output)
163 | def paste_key_hook(module, input, output):
164 | output[:] = keys.pop()
165 | def paste_value_hook(module, input, output):
166 | output[:] = values.pop()
167 |
168 | def register(do='copy'):
169 | do_copy = do == 'copy'
170 | key_hook, value_hook = (copy_key_hook, copy_value_hook) if do_copy else (paste_key_hook, paste_value_hook)
171 | key_handlers = []; value_handlers = []
172 | for block in (*sampler.unet.down_blocks, sampler.unet.mid_block, *sampler.unet.up_blocks):
173 | if not hasattr(block, 'attentions'):
174 | continue
175 | for attention in block.attentions:
176 | for tb in attention.transformer_blocks:
177 | key_handlers.append(tb.attn1.to_k.register_forward_hook(key_hook))
178 | value_handlers.append(tb.attn1.to_v.register_forward_hook(value_hook))
179 | return key_handlers, value_handlers
180 |
181 | def unregister(*handlers):
182 | for handler in handlers:
183 | handler.remove()
184 | torch.cuda.empty_cache()
185 |
186 | sde = encode_then_cp = method == 'Encode then CP'
187 | source = torch.from_numpy(source).to(device) if isinstance(source, np.ndarray) else source.to(device)
188 | target = torch.from_numpy(target).to(device) if isinstance(target, np.ndarray) else target.to(device)
189 | source = source // 8; target = target // 8 # from img scale to latent scale
190 |
191 | if encode_then_cp:
192 | blur_pts = source; copy_pts = source
193 | else:
194 | blur_pts = torch.cat([torch.from_numpy(get_border_points(target.cpu().numpy())).to(device), source], dim=0)
195 | copy_pts = target
196 | paste_pts = target
197 |
198 | latent = get_img_latent(ori_image, vae)
199 | preview_latent = get_img_latent(preview, vae) if not encode_then_cp else None
200 | sampler = Sampler(unet=unet, scheduler=scheduler, num_steps=steps)
201 |
202 | with torch.no_grad():
203 | neg_pooled_prompt_embeds, neg_prompt_embeds = get_text_embed("", tokenizer, text_encoder, tokenizer_2, text_encoder_2)
204 | neg_prompt_embeds = neg_prompt_embeds if sd_version == 'xl' else neg_pooled_prompt_embeds
205 | pooled_prompt_embeds, prompt_embeds = get_text_embed(prompt, tokenizer, text_encoder, tokenizer_2, text_encoder_2)
206 | prompt_embeds = prompt_embeds if sd_version == 'xl' else pooled_prompt_embeds
207 | prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
208 | pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
209 |
210 | image_embeds = ip_encode_image(feature_extractor, image_encoder, ori_image)
211 |
212 | H, W = ori_image.shape[:2]
213 | add_time_ids = torch.tensor([[H, W, 0, 0, H, W]]).to(prompt_embeds).repeat(2, 1)
214 | added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} if sd_version == 'xl' else {}
215 | added_cond_kwargs["image_embeds"] = image_embeds
216 |
217 | mask_pt = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)
218 | mask_pt = F.interpolate(mask_pt, size=latent.shape[2:]).expand_as(latent)
219 |
220 | if not encode_then_cp:
221 | hook_latents, noises, cfg_scales = forward(scheduler, sampler, steps, start_t, preview_latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde)
222 |
223 | keys = []; values = []
224 | key_handlers, value_handlers = register(do='copy')
225 | if encode_then_cp:
226 | hook_latents, noises, cfg_scales = forward(scheduler, sampler, steps, start_t, latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde)
227 | start_latent = None
228 | else:
229 | start_latent = forward(scheduler, sampler, steps, start_t, latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde)[0][-1]
230 | unregister(*key_handlers, *value_handlers)
231 |
232 | keys = reverse_and_repeat_every_n_elements(keys, n=len(key_handlers))
233 | values = reverse_and_repeat_every_n_elements(values, n=len(value_handlers))
234 |
235 | key_handlers, value_handlers = register(do='paste')
236 | latent = backward(scheduler, sampler, steps, start_t, end_t, noise_scale, hook_latents, noises, cfg_scales, mask_pt, prompt_embeds, added_cond_kwargs, blur_pts, copy_pts, paste_pts, latent=start_latent, progress=progress, sde=sde)
237 | unregister(*key_handlers, *value_handlers)
238 |
239 | image = postprocess(vae, latent, ori_image, mask)
240 |
241 | elif method == 'InstantDrag':
242 | global instant_pipe
243 | if 'instant_pipe' not in globals():
244 | instant_pipe = InstantDragPipeline(seed, 'cuda', torch.float16)
245 | flowgen_ckpt = next((m for m in sorted(os.listdir("checkpoints/")) if "flowgen" in m), None)
246 | flowdiffusion_ckpt = next(f for f in sorted(os.listdir("checkpoints/")) if "flowdiffusion" in f)
247 |
248 | print('Unused parameters in utils.drag function: input, start_t, end_t, noise_scale, progress')
249 | selected_points = [point.tolist() for pair in zip(source, target) for point in pair]
250 | ori_H, ori_W = ori_image.shape[:2]; new_H, new_W = 512, 512
251 | selected_points = torch.tensor(selected_points) / torch.tensor([ori_W, ori_H]) * torch.tensor([new_W, new_H])
252 | image_guidance, flow_guidance, flowgen_output_scale = 1.5, 1.5, -1.0
253 | image = instant_pipe.run(cv2.resize(ori_image, (new_W, new_H)), selected_points.tolist(), flowgen_ckpt, flowdiffusion_ckpt, image_guidance,
254 | flow_guidance, flowgen_output_scale, steps, save_results=False)
255 | image = cv2.resize(image, (ori_W, ori_H))
256 |
257 | else:
258 | raise ValueError('Select method from InstantDrag, Encode then CP, CP then Encode')
259 |
260 | if save_path:
261 | os.makedirs(save_path, exist_ok=True)
262 | counter = 0
263 | file_root, file_extension = os.path.splitext('dragged_image.png')
264 | while True:
265 | test_name = f"{file_root} ({counter}){file_extension}" if counter != 0 else 'dragged_image.png'
266 | full_path = os.path.join(save_path, test_name)
267 | if not os.path.exists(full_path):
268 | Image.fromarray(image).save(full_path)
269 | break
270 | counter += 1
271 | torch.cuda.empty_cache()
272 | return image
--------------------------------------------------------------------------------
/region_utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import cv2
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import lpips
11 | from diffusers import StableDiffusionPipeline
12 | from transformers import AutoModel
13 |
14 | warnings.filterwarnings(action='ignore', category=UserWarning)
15 |
16 | def plot_matching_result(src_image, trg_image, src_points, trg_points, pred_trg_points, output_path=None, border_width=10):
17 | """
18 | Used to visualize dragging effect.
19 | """
20 | if src_points.shape != trg_points.shape or src_points.shape != pred_trg_points.shape:
21 | raise ValueError(f"points arrays must have the same shapes, got Source:{src_points.shape}, Target:{trg_points.shape}, Predicted:{pred_trg_points.shape}")
22 | if src_image.shape[0] != trg_image.shape[0]:
23 | raise ValueError(f"Source image and target image must have same height, got Source:{src_image.shape}, Target:{trg_image.shape}")
24 |
25 | # Create a border and combine images
26 | border = np.ones((src_image.shape[0], border_width, 3), dtype=np.uint8) * 255
27 | combined_image = np.concatenate((src_image, border, trg_image), axis=1)
28 |
29 | # Adjust target and predicted target points by the width of the source image and the border
30 | trg_points_adj = trg_points + np.array([src_image.shape[1] + border_width, 0])
31 | pred_trg_points_adj = pred_trg_points + np.array([src_image.shape[1] + border_width, 0])
32 |
33 | # Create the plot
34 | fig, ax = plt.subplots(figsize=(10, 5))
35 | ax.imshow(combined_image.astype(np.uint8))
36 |
37 | # Draw arrows and points
38 | for src_pt, trg_pt, pred_trg_pt in zip(src_points, trg_points_adj, pred_trg_points_adj):
39 | ax.scatter(src_pt[0], src_pt[1], color='red', s=20)
40 | ax.scatter(trg_pt[0], trg_pt[1], color='blue', s=20)
41 | ax.scatter(pred_trg_pt[0], pred_trg_pt[1], color='red', s=20)
42 |
43 | # Draw arrows from src to trg
44 | ax.arrow(src_pt[0], src_pt[1], trg_pt[0] - src_pt[0], trg_pt[1] - src_pt[1],
45 | head_width=5, head_length=10, fc='white', ec='white', length_includes_head=True, lw=2, alpha=0.8)
46 |
47 | # Draw arrows from pred_trg to trg
48 | ax.arrow(pred_trg_pt[0], pred_trg_pt[1], trg_pt[0] - pred_trg_pt[0], trg_pt[1] - pred_trg_pt[1],
49 | head_width=5, head_length=10, fc='white', ec='white', length_includes_head=True, lw=2, alpha=0.8)
50 |
51 | ax.axis('off')
52 |
53 | # Save or show the image
54 | if output_path is not None:
55 | directory = os.path.dirname(output_path)
56 | if directory != '':
57 | os.makedirs(directory, exist_ok=True)
58 | plt.savefig(output_path, bbox_inches='tight', dpi=300) # Save with higher resolution
59 |
60 | plt.show()
61 | plt.close()
62 |
63 | return fig if output_path is None else output_path
64 |
65 |
66 | def create_mask(src_points, trg_points, img_size):
67 | """
68 | Creates a batch of masks based on the distance of image pixels to batches of given points.
69 |
70 | Args:
71 | src_points (torch.Tensor): The source points coordinates of shape (N, 2) [x, y].
72 | trg_points (torch.Tensor): The target points coordinates of shape (N, 2) [x, y].
73 | img_size (tuple): The size of the image (height, width).
74 |
75 | Returns:
76 | torch.Tensor: A batch of boolean masks where True indicates the pixel is within the distance for each point pair.
77 | """
78 | src_points = src_points.float()
79 | trg_points = trg_points.float()
80 |
81 | h, w = img_size
82 | point_distances = ((src_points - trg_points).norm(dim=1) / (2**0.5)).clamp(min=5) # Multiplying by 1/sqrt(2)
83 |
84 | y_indices, x_indices = torch.meshgrid(
85 | torch.arange(h, device=src_points.device),
86 | torch.arange(w, device=src_points.device),
87 | indexing="ij"
88 | )
89 |
90 | # Expand grid to match the batch size (y_indices, x_indices: shape [N, H, W])
91 | y_indices = y_indices.expand(src_points.size(0), -1, -1)
92 | x_indices = x_indices.expand(src_points.size(0), -1, -1)
93 |
94 | distance_to_p0 = ((x_indices - src_points[:, None, None, 0])**2 + (y_indices - src_points[:, None, None, 1])**2).sqrt()
95 | distance_to_p1 = ((x_indices - trg_points[:, None, None, 0])**2 + (y_indices - trg_points[:, None, None, 1])**2).sqrt()
96 | masks = (distance_to_p0 < point_distances[:, None, None]) | (distance_to_p1 < point_distances[:, None, None]) # (N, H, W)
97 |
98 | return masks
99 |
100 | def nn_get_matches(src_featmaps, trg_featmaps, query, l2_norm=True, mask=None):
101 | '''
102 | Find the nearest neighbour matches for a given query from source feature maps in target feature maps.
103 |
104 | Args:
105 | src_featmaps (torch.Tensor): Source feature map with shape (1 x C x H x W).
106 | trg_featmaps (torch.Tensor): Target feature map with shape (1 x C x H x W).
107 | query (torch.Tensor): (x, y) coordinates with shape (N x 2), must be in the range of src_featmaps.
108 | l2_norm (bool): If True, apply L2 normalization to features.
109 | mask (torch.Tensor): Optional mask with shape (N x H x W).
110 |
111 | Returns:
112 | torch.Tensor: (x, y) coordinates of the top matches with shape (N x 2).
113 | '''
114 | # Extract features from the source feature map at the query points
115 | _, c, h, w = src_featmaps.shape # (1, C, H, W)
116 | query = query.long()
117 | src_feat = src_featmaps[0, :, query[:, 1], query[:, 0]] # (C, N)
118 |
119 | if l2_norm:
120 | src_feat = F.normalize(src_feat, p=2, dim=0)
121 | trg_featmaps = F.normalize(trg_featmaps, p=2, dim=1)
122 |
123 | trg_featmaps = trg_featmaps.view(c, -1) # flatten (C, H*W)
124 | similarity = torch.mm(src_feat.t(), trg_featmaps) # similarity shape: (N, H*W)
125 |
126 | if mask is not None:
127 | mask = mask.view(-1, h * w) # mask shape: (N, H*W)
128 | similarity = torch.where(mask, similarity, torch.full_like(similarity, -torch.inf))
129 |
130 | # Get the indices of the best matches
131 | best_match_idx = similarity.argmax(dim=-1) # best_match_idx shape: (N,)
132 |
133 | # Convert flat indices to 2D coordinates
134 | y_coords = best_match_idx // w # y_coords shape: (N,)
135 | x_coords = best_match_idx % w # x_coords shape: (N,)
136 | coords = torch.stack((x_coords, y_coords), dim=1) # coords shape: (N, 2)
137 |
138 | return coords.float() # Output shape: (N, 2)
139 |
140 | class SDFeaturizer(StableDiffusionPipeline):
141 | """Used to extract SD2-1 feature for semantic point matching (DIFT)."""
142 | @torch.no_grad()
143 | def __call__(
144 | self,
145 | img_tensor,
146 | t=261,
147 | ensemble=8,
148 | prompt=None,
149 | prompt_embeds=None
150 | ):
151 | device = self._execution_device
152 | latents = self.vae.encode(img_tensor).latent_dist.mode() * self.vae.config.scaling_factor
153 | latents = latents.expand(ensemble, -1, -1, -1)
154 | t = torch.tensor(t, dtype=torch.long, device=device)
155 | noise = torch.randn_like(latents)
156 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
157 |
158 | if prompt_embeds is None:
159 | if prompt is None:
160 | prompt = ""
161 | prompt_embeds = self.encode_prompt(
162 | prompt=prompt,
163 | device=device,
164 | num_images_per_prompt=1,
165 | do_classifier_free_guidance=False
166 | )[0]
167 | prompt_embeds = prompt_embeds.expand(ensemble, -1, -1)
168 |
169 | # Cache output of second upblock of unet
170 | unet_feature = []
171 | def hook(module, input, output):
172 | unet_feature.clear()
173 | unet_feature.append(output)
174 | handle = list(self.unet.children())[4][1].register_forward_hook(hook=hook)
175 | self.unet(latents_noisy, t, prompt_embeds)
176 | handle.remove()
177 |
178 | return unet_feature[0].mean(dim=0, keepdim=True)
179 |
180 | class DragEvaluator:
181 | def __init__(self):
182 | self.clip_loaded = False
183 | self.dino_loaded = False
184 | self.sd_loaded = False
185 | self.lpips_loaded = False
186 |
187 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
188 | self.dtype = torch.float16
189 |
190 | def load_dino(self):
191 | if not self.dino_loaded:
192 | dino_path = 'facebook/dinov2-base'
193 | self.dino_model = AutoModel.from_pretrained(dino_path).to(self.device).to(self.dtype)
194 | self.dino_loaded = True
195 |
196 | def load_sd(self):
197 | if not self.sd_loaded:
198 | sd_path = 'stabilityai/stable-diffusion-2-1'
199 | self.sd_feat = SDFeaturizer.from_pretrained(sd_path, torch_dtype=self.dtype).to(self.device)
200 | self.sd_loaded = True
201 |
202 | def load_lpips(self):
203 | if not self.lpips_loaded:
204 | self.loss_fn_alex = lpips.LPIPS(net='alex').to(self.device).to(self.dtype)
205 | self.lpips_loaded = True
206 |
207 | def preprocess_image(self, image):
208 | image = torch.from_numpy(np.array(image)).float() / 127.5 - 1 # Normalize to [-1, 1]
209 | image = image.unsqueeze(0).permute(0, 3, 1, 2) # Rearrange to (1, C, H, W)
210 | return image.to(self.device).to(self.dtype)
211 |
212 | @torch.no_grad()
213 | def compute_lpips(self, image1, image2):
214 | """
215 | Learned Perceptual Image Patch Similarity (LPIPS) metric. (https://richzhang.github.io/PerceptualSimilarity/)
216 | """
217 | self.load_lpips()
218 |
219 | image1 = self.preprocess_image(image1)
220 | image2 = self.preprocess_image(image2)
221 | image1 = F.interpolate(image1, (224,224), mode='bilinear')
222 | image2 = F.interpolate(image2, (224,224), mode='bilinear')
223 |
224 | return self.loss_fn_alex(image1, image2).item()
225 |
226 | def encode_image(self, image, method, prompt=None):
227 | if method == 'dino':
228 | featmap = self.dino_model(image).last_hidden_state[:, 1:, :].permute(0, 2, 1)
229 | featmap = featmap.view(1, -1, 60, 60) # 60 = 840 / 14
230 | elif method == 'sd':
231 | featmap = self.sd_feat(image, prompt=prompt)
232 | else:
233 | raise NotImplementedError('Only SD and DINO supported.')
234 | return featmap
235 |
236 | @torch.no_grad()
237 | def compute_distance(self, src_image, trg_image, src_kps, trg_kps, method, prompt=None, plot_path=None):
238 | """ Mean Distance Metric """
239 | if method == 'dino':
240 | self.load_dino()
241 | elif method == 'sd':
242 | self.load_sd()
243 | else:
244 | raise NotImplementedError('Only SD and DINO supported.')
245 |
246 | src_kps = torch.tensor(src_kps, device=self.device).to(torch.long) # (N, 2) N points
247 | trg_kps = torch.tensor(trg_kps, device=self.device).to(torch.long) # (N, 2) N points
248 |
249 | # Resize target image and scale target points when necessary
250 | if src_image.shape != trg_image.shape:
251 | src_img_h, src_img_w, _ = src_image.shape
252 | trg_img_h, trg_img_w, _ = trg_image.shape
253 | trg_image = cv2.resize(trg_image, (src_img_w, src_img_h))
254 |
255 | trg_kps = trg_kps * torch.tensor([src_img_w, src_img_h], device=self.device)
256 | trg_kps = trg_kps / torch.tensor([trg_img_w, trg_img_h], device=self.device)
257 | trg_kps = trg_kps.to(torch.long)
258 |
259 | image_h, image_w, _ = src_image.shape
260 | image_size = 840 if method == 'dino' else 768
261 |
262 | source_image = self.preprocess_image(src_image) # 1, 3, H, W
263 | target_image = self.preprocess_image(trg_image) # 1, 3, H, W
264 | source_image = F.interpolate(source_image, size=(image_size, image_size)) # 1, 3, img_size, img_size
265 | target_image = F.interpolate(target_image, size=(image_size, image_size)) # 1, 3, img_size, img_size
266 |
267 | src_featmap = self.encode_image(source_image, method=method, prompt=prompt)
268 | src_featmap = F.interpolate(src_featmap, size=(image_h, image_w))
269 |
270 | trg_featmap = self.encode_image(target_image, method=method, prompt=prompt)
271 | trg_featmap = F.interpolate(trg_featmap, size=(image_h, image_w))
272 |
273 | mask = create_mask(src_kps, trg_kps, (image_h, image_w))
274 | pred_trg_kps = nn_get_matches(src_featmap, trg_featmap, src_kps, l2_norm=True, mask=mask)
275 |
276 | distance = trg_kps - pred_trg_kps
277 | distance[:, 0] /= image_w; distance[:, 1] /= image_h
278 | distance = distance.norm(dim=-1).mean().item()
279 |
280 | if plot_path:
281 | plot_matching_result(src_image, trg_image, src_kps.cpu().numpy(),trg_kps.cpu().numpy(), pred_trg_kps.cpu().numpy(), output_path=plot_path)
282 |
283 | return distance
--------------------------------------------------------------------------------
/region_utils/ui_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 | import gradio as gr
7 | import torch
8 | import time
9 | import functools
10 |
11 | from .drag import drag
12 |
13 | # 1. Common utils
14 | def timeit(func):
15 | """Decorator to measure the execution time of a function."""
16 | @functools.wraps(func)
17 | def wrapper(*args, **kwargs):
18 | start_time = time.time()
19 | result = func(*args, **kwargs)
20 | end_time = time.time()
21 | elapsed_time = end_time - start_time
22 | print(f"Function {func.__name__!r} took {elapsed_time:.4f} seconds to complete.")
23 | return result
24 |
25 | return wrapper
26 |
27 | def get_W_H(max_length, aspect_ratio):
28 | height = int(max_length / aspect_ratio) if aspect_ratio >= 1 else max_length
29 | width = max_length if aspect_ratio >= 1 else int(max_length * aspect_ratio)
30 | height = int(height / 8) * 8
31 | width = int(width / 8) * 8
32 | return width, height
33 |
34 | def canvas_to_image_and_mask(canvas):
35 | """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object."""
36 | image = canvas["image"].copy()
37 | mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy()
38 | return image, mask
39 |
40 | def draw_mask_border(image, mask):
41 | """Find the contours of shapes in the mask and draw them on the image."""
42 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
43 | cv2.drawContours(image, contours, -1, (255, 255, 255), 2)
44 |
45 | def mask_image(image, mask, color=[255,0,0], alpha=0.3):
46 | """Apply a binary mask to an image, highlighting the masked regions with a specified color and transparency."""
47 | out = image.copy()
48 | out[mask == 1] = color
49 | out = cv2.addWeighted(out, alpha, image, 1-alpha, 0, out)
50 | return out
51 |
52 | def resize_image(canvas, canvas_length, image_length):
53 | """Fit the image to an appropriate size."""
54 | if canvas is None:
55 | return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3
56 |
57 | image = canvas_to_image_and_mask(canvas)[0]
58 | image_h0, image_w0, _ = image.shape
59 | image_w1, image_h1 = get_W_H(image_length, image_w0 / image_h0)
60 | image = cv2.resize(image, (image_w1, image_h1))
61 |
62 |     # helpful when the uploaded gradio image has width > height
63 | canvas_w1, canvas_h1 = (canvas_length, canvas_length) if image_h0 > image_w0 else get_W_H(canvas_length, image_w0 / image_h0)
64 | return (gr.Image(value=image, width=canvas_w1, height=canvas_h1),) * 3
65 |
66 | def wrong_upload():
67 | """Prevent user to upload an image on the second column"""
68 | gr.Warning('You should upload an image on the left.')
69 | return None
70 |
71 | def save_data(original_image, input_image, preview_image, mask, src_points, trg_points, prompt, data_path):
72 | """
73 | save following under `data_path` directory
74 | (1) original.png : original image
75 | (2) input_image.png : input image
76 | (3) preview_image.png : preview image
77 | (4) meta_data_mask.pkl : {'prompt' : text prompt,
78 | 'points' : [(x1,y1), (x2,y2), ..., (xn, yn)],
79 | 'mask' : binary mask np.uint8 (H, W)}
80 | (x1, y1), (x3, y3), ... are from source (handle) points
81 | (x2, y2), (x4, y4), ... are from target points
82 | """
83 | os.makedirs(data_path, exist_ok=True)
84 | img_path = os.path.join(data_path, 'original_image.png')
85 | input_img_path = os.path.join(data_path, 'user_drag_region.png')
86 | preview_img_path = os.path.join(data_path, 'preview.png')
87 | meta_data_path = os.path.join(data_path, 'meta_data_mask.pkl')
88 |
89 | Image.fromarray(original_image).save(img_path)
90 | Image.fromarray(input_image).save(input_img_path)
91 | Image.fromarray(preview_image).save(preview_img_path)
92 |
93 | points = [point for pair in zip(src_points, trg_points) for point in pair]
94 |
95 | meta_data = {
96 | 'prompt': prompt,
97 | 'points': points,
98 | 'mask': mask
99 | }
100 | with open(meta_data_path, 'wb') as file:
101 | pickle.dump(meta_data, file, protocol=pickle.HIGHEST_PROTOCOL)
102 |
103 | @torch.no_grad()
104 | def run_process(canvas, input_image, preview_image, src_points, trg_points, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed, progress=gr.Progress()):
105 | if canvas is None:
106 | return None
107 |
108 | original_image, mask = canvas_to_image_and_mask(canvas)
109 | mask = np.ones_like(mask)
110 |
111 | if src_points is None or len(src_points) == 0:
112 | return original_image
113 |
114 | drag_data = {
115 | 'ori_image' : original_image,
116 | 'preview' : preview_image,
117 | 'prompt' : prompt,
118 | 'mask' : mask,
119 | 'source' : src_points,
120 | 'target' : trg_points
121 | }
122 |
123 | return drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress.tqdm, method, data_path)
124 |
125 | # 2. mask utils (region pairs)
126 | def clear_all_m(length):
127 | """Used to initialize all inputs in ui's region pair tab."""
128 | return (gr.Image(value=None, height=length, width=length),) * 3 + \
129 | ([], "A photo of an object.", "output/default", 20, 0.6, None, None)
130 |
131 | def draw_input_image_m(canvas, selected_masks):
132 | """Draw an image reflecting user's intentions."""
133 | image = canvas_to_image_and_mask(canvas)[0]
134 | for i, mask in enumerate(selected_masks):
135 | color = [255, 0, 0] if i % 2 == 0 else [0, 0, 255]
136 | image = mask_image(image, mask, color=color, alpha=0.3)
137 | draw_mask_border(image, mask)
138 | return image
139 |
140 | @torch.no_grad()
141 | def region_pair_to_pts(src_region, trg_region, scale=1):
142 | """
143 |     Perform dense mapping between one source (handle) and one target region.
144 | `scale` is set to 1/8 for mapping in SD latent space.
145 | """
146 |
147 | def mask_min_max(tensor, mask, dim=None):
148 | """
149 | Compute the masked max or min of a tensor along a given dimension.
150 | """
151 | # Apply the mask by using a very small or very large number for min/max respectively
152 | masked_tensor = torch.where(mask, tensor, torch.inf)
153 | masked_min = torch.min(masked_tensor, dim=dim)[0] if dim is not None else torch.min(masked_tensor)
154 |
155 | masked_tensor = torch.where(mask, tensor, -torch.inf)
156 | masked_max = torch.max(masked_tensor, dim=dim)[0] if dim is not None else torch.max(masked_tensor)
157 | return masked_min, masked_max
158 |
159 | src_region = cv2.resize(src_region, (int(src_region.shape[1]*scale), int(src_region.shape[0]*scale)))
160 | trg_region = cv2.resize(trg_region, (int(trg_region.shape[1]*scale), int(trg_region.shape[0]*scale)))
161 |
162 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
163 | src_region = torch.from_numpy(src_region).to(device).bool()
164 | trg_region = torch.from_numpy(trg_region).to(device).bool()
165 |
166 | h, w = src_region.shape
167 | src_grid = trg_grid = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij'), dim=-1).to(device).float()
168 | trg_pts = trg_grid[torch.where(trg_region)]
169 |
170 | src_x_min, src_x_max = mask_min_max(src_grid[:, :, 1], mask=src_region)
171 | trg_x_min, trg_x_max = mask_min_max(trg_grid[:, :, 1], mask=trg_region)
172 |
173 | scale_x = (src_x_max - src_x_min) / (trg_x_max - trg_x_min).clamp(min=1e-4)
174 |
175 | trg_grid[:, :, 1] = ((trg_grid[:, :, 1] - trg_x_min) * scale_x + src_x_min)
176 | trg_grid[:, :, 1] = torch.where(trg_region, trg_grid[:, :, 1], 0)
177 |
178 | src_y_min, src_y_max = mask_min_max(src_grid[:, :, 0], mask=src_region, dim=0)
179 | trg_y_min, trg_y_max = mask_min_max(trg_grid[:, :, 0], mask=trg_region, dim=0)
180 | src_y_min, src_y_max = src_y_min[trg_grid[:, :, 1].int()], src_y_max[trg_grid[:, :, 1].int()]
181 |
182 | scale_y = (src_y_max - src_y_min) / (trg_y_max - trg_y_min).clamp(min=1e-4)
183 | trg_grid[:, :, 0] = ((trg_grid[:, :, 0] - trg_y_min) * scale_y + src_y_min)
184 | warp_trg_pts = trg_grid[torch.where(trg_region)]
185 |
186 | return warp_trg_pts[:, [1, 0]].int(), trg_pts[:, [1, 0]].int()
187 |
188 | def preview_out_image_m(canvas, selected_masks):
189 | """Preview the output image by directly copy-pasting pixel values."""
190 | if canvas is None:
191 | return None, None, None
192 | image = canvas_to_image_and_mask(canvas)[0]
193 |
194 | if len(selected_masks) < 2:
195 | return image, None, None
196 |
197 | src_regions = selected_masks[0:-1:2]
198 | trg_regions = selected_masks[1::2]
199 |
200 | src_points, trg_points = map(torch.cat, zip(*[region_pair_to_pts(src_region, trg_region) for src_region, trg_region in zip(src_regions, trg_regions)]))
201 | src_idx, trg_idx = src_points[:, [1, 0]].cpu().numpy(), trg_points[:, [1, 0]].cpu().numpy()
202 | image[tuple(trg_idx.T)] = image[tuple(src_idx.T)]
203 |
204 | src_points, trg_points = map(torch.cat, zip(*[region_pair_to_pts(src_region, trg_region, scale=1/8) for src_region, trg_region in zip(src_regions, trg_regions)]))
205 | return image, src_points*8, trg_points*8
206 |
207 | def add_mask(canvas, selected_masks):
208 | """Add a drawn mask, and draw input image"""
209 | if canvas is None:
210 | return None
211 | image, mask = canvas_to_image_and_mask(canvas)
212 | if len(selected_masks) >= 1 and (mask == 0).all():
213 | gr.Warning('Do not input empty region.')
214 | else:
215 | selected_masks.append(mask)
216 | return draw_input_image_m(canvas, selected_masks)
217 |
218 | def undo_mask(canvas, selected_masks):
219 | """Undo a drawn mask, and draw input image"""
220 | if len(selected_masks) > 0:
221 | selected_masks.pop()
222 | if canvas is None:
223 | return None
224 | return draw_input_image_m(canvas, selected_masks)
225 |
226 | def clear_masks(canvas, selected_masks):
227 | """Clear all drawn masks, and draw input image"""
228 | selected_masks.clear()
229 | if canvas is None:
230 | return None
231 | return draw_input_image_m(canvas, selected_masks)
232 |
233 | # 3. vertex utils (polygon pairs)
234 | def clear_all(length):
235 |     """Used to initialize all inputs in ui's vertex pair tab."""
236 | return (gr.Image(value=None, height=length, width=length),) * 3 + \
237 | ([], [], "A photo of an object.", "output/default", 20, 0.6, None, None)
238 |
239 | def draw_input_image(canvas, selected_points, selected_shapes):
240 | """Draw input image with vertices."""
241 | # Extract the image and mask from the canvas
242 | image, mask = canvas_to_image_and_mask(canvas)
243 | if mask.sum() > 0:
244 | gr.Info('The drawn mask is not used.')
245 | mask = np.ones_like(mask)
246 | # If a mask is present (i.e., sum of mask values > 0), non-masked parts will be darkened
247 | masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) if mask.sum() > 0 else image
248 |
249 | def draw_circle(image, point, text, color):
250 | font = cv2.FONT_HERSHEY_SIMPLEX
251 | font_scale = 0.3
252 | font_color = (255, 255, 255)
253 | font_thickness = 1
254 | text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
255 | bottom_left_corner = (point[0] - text_size[0] // 2, point[1] + text_size[1] // 2)
256 |
257 | cv2.circle(image, tuple(point), 8, color, -1)
258 | cv2.circle(image, tuple(point), 8, [255, 255, 255])
259 | cv2.putText(image, text, bottom_left_corner, font, font_scale, font_color, font_thickness)
260 |
261 | def draw_polygon(image, points, shape, color):
262 | if len(points) != shape:
263 | return image
264 | mask = np.zeros(image.shape[:2], dtype=np.uint8)
265 | points = np.array(points).reshape((-1, 1, 2))
266 | cv2.fillPoly(mask, [points], 1)
267 | return mask_image(image, mask, color, alpha=0.3)
268 |
269 | if len(selected_points) == 0:
270 | return masked_img
271 |
272 | start_idx = 1
273 | for points, shape in zip(selected_points, selected_shapes):
274 | src_pts, trg_pts = points[:shape], points[shape:]
275 | masked_img = draw_polygon(masked_img, src_pts, shape, color=[255, 0, 0])
276 | masked_img = draw_polygon(masked_img, trg_pts, shape, color=[0, 0, 255])
277 |
278 | for i, src_pt in enumerate(src_pts, start=start_idx):
279 | draw_circle(masked_img, src_pt, str(i), [255, 102, 102])
280 | for i, trg_pt in enumerate(trg_pts, start=start_idx):
281 | draw_circle(masked_img, trg_pt, str(i), [102, 102, 255])
282 | start_idx = i + 1
283 |
284 | return masked_img
285 |
286 | def transform_polygon(src_points, trg_points, scale=1):
287 | """
288 | Perform dense mapping with source (handle) and target triangle or quadrilateral.
289 | """
290 | def get_points_inside_polygon(points):
291 | points = np.array(points, dtype=np.int32)
292 | x_max, y_max = np.amax(points, axis=0) + 1
293 | mask = np.zeros((y_max, x_max), dtype=np.uint8)
294 | cv2.fillPoly(mask, [points], 1)
295 | return np.column_stack(np.where(mask == 1))[:, [1, 0]]
296 |
297 | if len(trg_points) not in [3, 4]:
298 | raise NotImplementedError('Only triangles and quadrilaterals are implemented')
299 |
300 | src_points, trg_points = np.float32(src_points)*scale, np.float32(trg_points)*scale
301 | M = (cv2.getAffineTransform if len(trg_points) == 3 else cv2.getPerspectiveTransform)(trg_points, src_points)
302 | points_inside = get_points_inside_polygon(trg_points)
303 | warped_points = (cv2.transform if len(trg_points) == 3 else cv2.perspectiveTransform)(np.array([points_inside], dtype=np.float32), M)
304 |
305 | return warped_points[0].astype(np.int32), points_inside
306 |
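# The following sketch is illustrative only (kept as comments so the module behaviour is unchanged);
# the quad coordinates are made up for the example. It shows how transform_polygon pairs up pixels
# when the target quad is simply the source quad shifted 10 px to the right:
#
#   src_quad = [(20, 20), (40, 20), (40, 40), (20, 40)]
#   trg_quad = [(30, 20), (50, 20), (50, 40), (30, 40)]
#   src_pts, trg_pts = transform_polygon(src_quad, trg_quad)
#   # trg_pts lists every integer (x, y) inside trg_quad, and src_pts[i] == trg_pts[i] - (10, 0)
#   # (up to rounding), so copying image[src] into image[trg] translates the patch by +10 px in x.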
307 | def preview_out_image(canvas, selected_points, selected_shapes):
308 | if canvas is None:
309 | return None, None, None
310 | image = canvas_to_image_and_mask(canvas)[0]
311 |
312 | selected_points = selected_points.copy()
313 | selected_shapes = selected_shapes.copy()
314 |
315 | if len(selected_points) == 0:
316 | return image, None, None
317 |
318 | if len(selected_points[-1]) != selected_shapes[-1] * 2:
319 | selected_points.pop()
320 | selected_shapes.pop()
321 | if len(selected_points) == 0:
322 | return image, None, None
323 |
324 | src_points, trg_points = map(np.concatenate, zip(*[transform_polygon(sp[:ss], sp[ss:]) for sp, ss in zip(selected_points, selected_shapes)]))
325 | src_idx, trg_idx = src_points[:, [1, 0]], trg_points[:, [1, 0]]
326 | image[tuple(trg_idx.T)] = image[tuple(src_idx.T)]
327 |
328 | src_points, trg_points = map(np.concatenate, zip(*[transform_polygon(sp[:ss], sp[ss:], scale=1/8) for sp, ss in zip(selected_points, selected_shapes)]))
329 | return image, src_points*8, trg_points*8
330 |
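# Note on the two passes above (illustrative comment; the factor 8 matches the 8x spatial
# downsampling of the Stable Diffusion VAE): the first transform_polygon pass runs at full
# resolution and is only used to paste image[src] into image[trg] for the on-screen preview,
# while the second pass repeats the mapping at scale=1/8, i.e. on the latent grid, and the
# resulting coordinates are multiplied by 8 so they are expressed in pixel units aligned to
# the 8-px latent cells before being handed to run_process in ui.py.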
331 | def add_point(canvas, shape, selected_points, selected_shapes, evt: gr.SelectData):
332 | """Collect the selected point, and draw the input image."""
333 | if canvas is None:
334 | return None
335 |
336 | def is_valid_quadrilateral(p1, p2, p3, p4):
337 | def orientation(p, q, r):
338 | val = (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1])
339 | if val == 0:
340 | return 0 # collinear
341 |             return 1 if val > 0 else 2  # clockwise or counterclockwise
342 |
343 | def do_intersect(a1, a2, b1, b2):
344 | o1 = orientation(a1, a2, b1)
345 | o2 = orientation(a1, a2, b2)
346 | o3 = orientation(b1, b2, a1)
347 | o4 = orientation(b1, b2, a2)
348 | return (o1 != o2 and o3 != o4) and not (o1 == 0 or o2 == 0 or o3 == 0 or o4 == 0)
349 |
350 | # Check for collinearity and intersection between non-adjacent edges
351 | return not (orientation(p1, p2, p3) == 0 or
352 | orientation(p1, p2, p4) == 0 or
353 | orientation(p2, p3, p4) == 0 or
354 | orientation(p1, p3, p4) == 0 or
355 | do_intersect(p1, p2, p3, p4) or
356 | do_intersect(p2, p3, p1, p4))
357 |
358 | if len(selected_points) == 0 or len(selected_points[-1]) == 2 * selected_shapes[-1]:
359 | selected_points.append([evt.index])
360 |         selected_shapes.append(shape+3)  # radio index (0 = Tri, 1 = Quad) -> vertex count (3 or 4)
361 | else:
362 | selected_points[-1].append(evt.index)
363 | if selected_shapes[-1] == 4 and len(selected_points[-1]) % selected_shapes[-1] == 0:
364 | if not is_valid_quadrilateral(*selected_points[-1][-4:]):
365 | gr.Warning('The drawn quadrilateral is not valid.')
366 | selected_points[-1].pop()
367 |
368 | return draw_input_image(canvas, selected_points, selected_shapes)
369 |
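# Illustrative sketch (comments only, coordinates made up): what is_valid_quadrilateral accepts
# and rejects. Clicking the four corners of a square in order is fine, while clicking them in a
# criss-cross ("bowtie") order makes two non-adjacent edges intersect, so add_point drops the
# last click and shows the warning above:
#
#   is_valid_quadrilateral((0, 0), (10, 0), (10, 10), (0, 10))   # True  - simple convex quad
#   is_valid_quadrilateral((0, 0), (10, 10), (10, 0), (0, 10))   # False - edges cross at (5, 5)
#
# Quads with three collinear corners are rejected as well, since the orientation checks return 0.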
370 | def update_shape(canvas, shape, selected_points, selected_shapes):
371 | """Allow users to switch between different shape options"""
372 | if canvas is None:
373 | return None
374 | if len(selected_points) == 0:
375 | return draw_input_image(canvas, selected_points, selected_shapes)
376 | if len(selected_points[-1]) == selected_shapes[-1] * 2:
377 | return draw_input_image(canvas, selected_points, selected_shapes)
378 |
379 | selected_shapes[-1] = shape + 3
380 | selected_points[-1] = selected_points[-1][ : (shape+3)*2]
381 | return draw_input_image(canvas, selected_points, selected_shapes)
382 |
383 | def undo_point(canvas, shape, selected_points, selected_shapes):
384 | """Remove the last added point, and draw the input image."""
385 | if canvas is None:
386 | return None
387 | if len(selected_points) == 0:
388 | return draw_input_image(canvas, selected_points, selected_shapes)
389 |
390 | selected_points[-1].pop()
391 | if len(selected_points[-1]) == 0:
392 | selected_points.pop()
393 | selected_shapes.pop()
394 |         update_shape(canvas, shape, selected_points, selected_shapes)  # re-sync the last shape with the radio selection; the returned image is discarded
395 | return draw_input_image(canvas, selected_points, selected_shapes)
396 |
397 | def clear_points(canvas, selected_points, selected_shapes):
398 | """Clear all the points"""
399 | selected_points.clear()
400 | selected_shapes.clear()
401 | if canvas is None:
402 | return None
403 | return draw_input_image(canvas, selected_points, selected_shapes)
404 |
405 | # 4. region utils (region + point pair)
406 | def draw_input_image_r(canvas, selected_points):
407 | image, mask = canvas_to_image_and_mask(canvas)
408 | image = mask_image(image, mask, color=[255, 0, 0], alpha=0.3)
409 | draw_mask_border(image, mask)
410 | for idx, point in enumerate(selected_points, start=1):
411 | if idx % 2 == 0:
412 | cv2.circle(image, tuple(point), 10, (0, 0, 255), -1)
413 | cv2.arrowedLine(image, last_point, point, (255, 255, 255), 4, tipLength=0.5)
414 | else:
415 | cv2.circle(image, tuple(point), 10, (255, 0, 0), -1)
416 | last_point = point
417 | return image
418 |
419 |
420 | def region_to_points(region_mask, selected_points, scale=1):
421 | """
422 |     Convert a region mask and an alternating list of (handle, target) click points into dense
423 |     source/target point correspondences, optionally computed at a scaled-down resolution (a usage sketch follows this function).
424 | """
425 | def resize_region_mask(mask, scale_factor):
426 | """Resize the region mask by a scale factor."""
427 | H, W = mask.shape
428 | new_H, new_W = (int(H * scale_factor), int(W * scale_factor))
429 | return cv2.resize(mask, (new_W, new_H)), (new_H, new_W)
430 |
431 | def find_contours(mask):
432 | """Find contours in the mask."""
433 | return cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
434 |
435 | def scale_points(points, scale_factor):
436 | """Scale points by a scale factor."""
437 | return (np.array(points) * scale_factor).astype(np.int32)
438 |
439 | def draw_mask_from_contours(contour, shape):
440 | """Draws a mask from given contours."""
441 | mask = np.zeros(shape, dtype=np.uint8)
442 | contour = contour[:, np.newaxis, :] if contour.ndim == 2 else contour
443 | cv2.drawContours(mask, [contour], -1, color=255, thickness=cv2.FILLED)
444 | return mask
445 |
446 | def find_points_in_mask(mask):
447 | """Find the coordinates of non-zero points in the mask."""
448 | return np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]]
449 |
450 | def filter_points_by_bounds(source_points, target_points, bounds):
451 | """Filter points by checking if they fall within the given bounds."""
452 | height, width = bounds
453 | src_condition = (
454 | (source_points[:, 0] >= 0) & (source_points[:, 0] < width) &
455 | (source_points[:, 1] >= 0) & (source_points[:, 1] < height)
456 | )
457 | trg_condition = (
458 | (target_points[:, 0] >= 0) & (target_points[:, 0] < width) &
459 | (target_points[:, 1] >= 0) & (target_points[:, 1] < height)
460 | )
461 | condition = src_condition & trg_condition
462 | return source_points[condition], target_points[condition]
463 |
464 | def find_matching_points(source_region, source_points, target_points):
465 | """Find source points in source_region and their matching target points."""
466 | match_indices = np.all(source_points[:, None] == source_region, axis=2).any(axis=1)
467 | return source_points[match_indices], target_points[match_indices]
468 |
469 | def interpolate_points(region_points, reference_points, directions, max_num=100):
470 | """Interpolate points within a region based on reference points and their directions."""
471 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
472 |
473 | # Convert input numpy arrays to PyTorch tensors and send them to the appropriate device
474 | region_points = torch.from_numpy(region_points).half().to(device)
475 | reference_points = torch.from_numpy(reference_points).half().to(device)
476 | directions = torch.from_numpy(directions).half().to(device)
477 |
478 | if len(reference_points) < 2:
479 | return (region_points + directions).int().to('cpu').numpy()
480 |
481 | if len(reference_points) > max_num:
482 | indices = torch.linspace(0, len(reference_points) - 1, steps=max_num).long().to(device)
483 | reference_points = reference_points[indices]
484 | directions = directions[indices]
485 |
486 | distance = torch.norm(region_points.unsqueeze(1) - reference_points.unsqueeze(0), dim=-1)
487 | _, indices = torch.sort(distance, dim=1)
488 | indices = indices[:, :min(4, reference_points.shape[0])]
489 |
490 | directions = torch.gather(directions.unsqueeze(0).expand(region_points.size(0), -1, -1), 1, indices.unsqueeze(-1).expand(-1, -1, directions.size(-1)))
491 |
492 | inv_distance = 1 / (torch.gather(distance, 1, indices) + 1e-4)
493 | weight = inv_distance / inv_distance.sum(dim=1, keepdim=True)
494 |
495 | estimated_direction = (weight.unsqueeze(-1) * directions).sum(dim=1)
496 | return (region_points + estimated_direction).round().int().to('cpu').numpy()
497 |
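    # Illustrative note on interpolate_points (inverse-distance weighting): each region point
    # blends the drag directions of its (up to) 4 nearest reference points with weights
    # proportional to 1 / (distance + 1e-4). For example, a point at distance 1 from a
    # reference moving by (+8, 0) and at distance 3 from one moving by (0, +8) gets weights
    # of roughly (0.75, 0.25) and therefore moves by about (+6, +2) before rounding.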
498 | resized_mask, new_size = resize_region_mask(region_mask, scale)
499 |
500 | contours = find_contours(resized_mask)
501 | source_points = scale_points(selected_points[0:-1:2], scale)
502 | target_points = scale_points(selected_points[1::2], scale)
503 |
504 | source_regions = [np.zeros((0, 2)),]; target_regions = [np.zeros((0, 2)),]
505 | for contour in contours:
506 | # find point pairs used to manipulate the region inside this contour
507 | source_contour = contour[:, 0, :]
508 | source_region_points = find_points_in_mask(draw_mask_from_contours(source_contour, new_size))
509 | source, target = find_matching_points(source_region_points, source_points, target_points)
510 |
511 |         # interpolate to find the motion of the contour points and the points inside
512 | if len(source) == 0:
513 | continue
514 | target_contour = interpolate_points(source_contour, source, target - source)
515 | interpolated_target_points = interpolate_points(source_region_points, source, target - source)
516 |
517 |         # as above, but in reverse: obtain a source reference point for every point inside the target region
518 | target_region_points = find_points_in_mask(draw_mask_from_contours(target_contour, new_size))
519 | interpolated_source_points = interpolate_points(target_region_points, interpolated_target_points, source_region_points - interpolated_target_points)
520 |
521 | filtered_source, filtered_target = filter_points_by_bounds(interpolated_source_points, target_region_points, new_size)
522 | source_regions.append(filtered_source)
523 | target_regions.append(filtered_target)
524 |
525 | return np.concatenate(source_regions).astype(np.int32), np.concatenate(target_regions).astype(np.int32)
526 |
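# Illustrative usage sketch (comments only; the mask size and click coordinates are made up).
# This is roughly how preview_out_image_r below calls region_to_points: the mask is the
# user-drawn red region and selected_points is the flat [handle_1, target_1, handle_2, ...]
# list collected by add_point_r:
#
#   region_mask = np.zeros((512, 512), dtype=np.uint8)
#   region_mask[100:200, 100:200] = 255                 # drawn region (rows/cols 100..199)
#   clicks = [[150, 150], [190, 150]]                   # one handle point, dragged 40 px right
#   src, trg = region_to_points(region_mask, clicks)
#   # src[i] is the (x, y) pixel to copy from and trg[i] the (x, y) pixel to copy to,
#   # together forming a dense correspondence covering the whole translated region.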
527 | def preview_out_image_r(canvas, selected_points):
528 | if canvas is None:
529 |         return None, None, None
530 | image, region_mask = canvas_to_image_and_mask(canvas)
531 |
532 | if len(selected_points) < 2:
533 | return image, None, None
534 |
535 | src_points, trg_points = region_to_points(region_mask, selected_points)
536 | image[trg_points[:, 1], trg_points[:, 0]] = image[src_points[:, 1], src_points[:, 0]]
537 | src_points, trg_points = region_to_points(region_mask, selected_points, scale=1/8)
538 | return image, src_points*8, trg_points*8
539 |
540 | def add_point_r(canvas, selected_points, evt: gr.SelectData):
541 | if canvas is None:
542 | return None
543 | selected_points.append(evt.index)
544 | return draw_input_image_r(canvas, selected_points)
545 |
546 | def undo_point_r(canvas, selected_points):
547 | if canvas is None:
548 | return None
549 | if len(selected_points) > 0:
550 | selected_points.pop()
551 | return draw_input_image_r(canvas, selected_points)
552 |
553 | def clear_points_r(canvas, selected_points):
554 | if canvas is None:
555 | return None
556 | selected_points.clear()
557 | return draw_input_image_r(canvas, selected_points)
558 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.0
2 | diffusers==0.30.1
3 | transformers==4.40.2
4 | safetensors==0.4.3
5 | einops==0.7.0
6 | lpips==0.1.4
7 | cmake==3.25.0
8 | gradio==3.47.1
9 | matplotlib==3.7.1
10 | numpy==1.24.1
11 | opencv-python==4.8.0.76
12 | pandas==2.0.2
13 | Pillow==9.3.0
14 | torch==2.0.1
15 | torchaudio==2.0.2
16 | torchvision==0.15.2
17 | tqdm==4.65.0
18 | xformers==0.0.20
19 | yacs==0.1.8
--------------------------------------------------------------------------------
/run_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from tqdm import tqdm
5 | import gradio as gr
6 |
7 | from region_utils.drag import drag, get_drag_data, get_meta_data
8 | from region_utils.evaluator import DragEvaluator
9 |
10 | # Setting up the argument parser
11 | parser = argparse.ArgumentParser(description='Run the drag operation.')
12 | parser.add_argument('--data_dir', type=str, default='drag_data/dragbench-dr/') # OR 'drag_data/dragbench-sr/'
13 | args = parser.parse_args()
14 |
15 | evaluator = DragEvaluator()
16 | all_distances = []; all_lpips = []
17 |
18 | data_dir = args.data_dir
19 | data_dirs = [dirpath for dirpath, dirnames, _ in os.walk(data_dir) if not dirnames]
20 |
21 | start_t = 0.5       # same hidden default as in ui.py
22 | end_t = 0.2         # same hidden default as in ui.py
23 | steps = 20          # sampling steps
24 | noise_scale = 1.    # handle noise scale (alpha)
25 | seed = 42           # generation seed
26 |
27 | for data_path in tqdm(data_dirs):
28 | # Region-based Inputs for Editing
29 | drag_data = get_drag_data(data_path)
30 | ori_image = drag_data['ori_image']
31 | out_image = drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress=gr.Progress())
32 |
33 | # Point-based Inputs for Evaluation
34 | meta_data_path = os.path.join(data_path, 'meta_data.pkl')
35 | prompt, _, source, target = get_meta_data(meta_data_path)
36 |
37 | all_distances.append(evaluator.compute_distance(ori_image, out_image, source, target, method='sd', prompt=prompt))
38 | all_lpips.append(evaluator.compute_lpips(ori_image, out_image))
39 |
40 | if all_distances:
41 | mean_dist = torch.tensor(all_distances).mean().item()
42 | mean_lpips = torch.tensor(all_lpips).mean().item()
43 | print(f'MD: {mean_dist:.4f}\nLPIPS: {mean_lpips:.4f}\n')
--------------------------------------------------------------------------------
/ui.py:
--------------------------------------------------------------------------------
1 | import torch
2 | if not torch.cuda.is_available():
3 | raise RuntimeError('CUDA is not available, but required.')
4 |
5 | import gradio as gr
6 | from region_utils.ui_utils import *
7 | from region_utils.drag import sd_version
8 |
9 | LENGTH = 400 # side length of images in the Gradio app; adjust it to fit your screen size
10 | GEN_SIZE = {'v1-5': 512, 'v2-1': 768, 'xl': 1024}[sd_version] # Default generated image size
11 |
12 | def main():
13 | with gr.Blocks() as demo:
14 | gr_length = gr.Number(value=LENGTH, visible=False, precision=0)
15 | gr_gen_size = gr.Number(value=GEN_SIZE, visible=False, precision=0)
16 |
17 | selected_masks = gr.State(value=[])
18 | src_points_m = gr.State(value=None); trg_points_m = gr.State(value=None)
19 |
20 | selected_points = gr.State(value=[])
21 | selected_shapes = gr.State(value=[])
22 | src_points = gr.State(value=None); trg_points = gr.State(value=None)
23 |
24 | selected_points_r = gr.State(value=[])
25 | src_points_r = gr.State(value=None); trg_points_r = gr.State(value=None)
26 |
27 | seed = gr.Number(value=42, label="Generation Seed", precision=0, visible=False)
28 | start_t = gr.Number(value=0.5, visible=False)
29 | end_t = gr.Number(value=0.2, visible=False)
30 |
31 | # layout definition
32 | with gr.Row():
33 | gr.Markdown("""
34 | # Official Implementation of [RegionDrag](https://arxiv.org/abs/2407.18247)
35 |             #### Explore our detailed [User Guide](https://github.com/LuJingyi-John/RegionDrag/blob/main/UI_GUIDE.md) for interface instructions.
36 | """)
37 |
38 | with gr.Row():
39 | with gr.Tab(label='Region pairs'):
40 | with gr.Row():
41 | with gr.Column():
42 | gr.Markdown("""1. Upload image and add regions
""")
43 | canvas_m = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH)
44 | with gr.Row():
45 | resize_button_m = gr.Button("Fit Canvas")
46 | add_mask_button = gr.Button("Add Region")
47 | with gr.Column():
48 | gr.Markdown("""2. View Input
""")
49 | input_image_m = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False)
50 | with gr.Row():
51 | undo_mask_button = gr.Button("Undo Region")
52 | clear_mask_button = gr.Button("Clear Region")
53 | with gr.Column():
54 | gr.Markdown("""Results
""")
55 | output_image_m = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False)
56 | with gr.Row():
57 | run_button_m = gr.Button("Run Drag")
58 | clear_all_button_m = gr.Button("Clear All")
59 |
60 | with gr.Tab(label='Polygon pairs'):
61 | with gr.Row():
62 | with gr.Column():
63 | gr.Markdown("""1. Upload image
""")
64 | canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH, interactive=True)
65 | with gr.Row():
66 | resize_button = gr.Button("Fit Canvas")
67 | with gr.Column():
68 | gr.Markdown("""2. Click vertices
""")
69 | input_image = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=True)
70 | with gr.Row():
71 | undo_point_button = gr.Button("Undo Point")
72 | clear_point_button = gr.Button("Clear Point")
73 | with gr.Column():
74 | gr.Markdown("""Results
""")
75 | output_image = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False)
76 | with gr.Row():
77 | run_button = gr.Button("Run Drag")
78 | clear_all_button = gr.Button("Clear All")
79 | shape = gr.Radio(choices=['▲ Tri', '■ Quad'], value='■ Quad', type='index', label='Mask Shape', interactive=True)
80 |
81 | with gr.Tab(label='Region + points'):
82 | with gr.Row():
83 | with gr.Column():
84 | gr.Markdown("""1. Upload image and draw regions
""")
85 | canvas_r = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH)
86 | with gr.Row():
87 | resize_button_r = gr.Button("Fit Canvas")
88 | with gr.Column():
89 | gr.Markdown("""2. Click points to control regions
""")
90 | input_image_r = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=True)
91 | with gr.Row():
92 | undo_point_button_r = gr.Button("Undo Point")
93 | clear_point_button_r = gr.Button("Clear Point")
94 | with gr.Column():
95 | gr.Markdown("""Results
""")
96 | output_image_r = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False)
97 | with gr.Row():
98 | run_button_r = gr.Button("Run Drag")
99 | clear_all_button_r = gr.Button("Clear All")
100 |
101 | with gr.Tab("Generation Parameters"):
102 | with gr.Row():
103 | prompt = gr.Textbox(label="Prompt describing output image (Optional)", value='A photo of an object.')
104 | data_path = gr.Textbox(value='output/default', label="Output path")
105 | with gr.Row():
106 | steps = gr.Slider(minimum=20, maximum=100, value=20, step=20, label='Sampling steps', interactive=True)
107 | noise_scale = gr.Slider(minimum=0, maximum=1.6, value=0.6, step=0.2, label='Handle Noise Scale', interactive=True) # alpha
108 | method = gr.Dropdown(choices=['InstantDrag', 'Encode then CP', 'CP then Encode'], value='InstantDrag', label='Method', interactive=True)
109 |
110 | clear_all_button_m.click(
111 | clear_all_m,
112 | [gr_length],
113 | [canvas_m, input_image_m, output_image_m, selected_masks, prompt, data_path, steps, noise_scale, src_points_m, trg_points_m]
114 | )
115 | canvas_m.clear(
116 | clear_all_m,
117 | [gr_length],
118 | [canvas_m, input_image_m, output_image_m, selected_masks, prompt, data_path, steps, noise_scale, src_points_m, trg_points_m]
119 | )
120 | resize_button_m.click(
121 | clear_masks,
122 | [canvas_m, selected_masks],
123 | [input_image_m]
124 | ).then(
125 | resize_image,
126 | [canvas_m, gr_length, gr_gen_size],
127 | [canvas_m, input_image_m, output_image_m]
128 | )
129 | add_mask_button.click(
130 | add_mask,
131 | [canvas_m, selected_masks],
132 | [input_image_m]
133 | ).then(
134 | preview_out_image_m,
135 | [canvas_m, selected_masks],
136 | [output_image_m, src_points_m, trg_points_m]
137 | )
138 | undo_mask_button.click(
139 | undo_mask,
140 | [canvas_m, selected_masks],
141 | [input_image_m]
142 | ).then(
143 | preview_out_image_m,
144 | [canvas_m, selected_masks],
145 | [output_image_m, src_points_m, trg_points_m]
146 | )
147 | clear_mask_button.click(
148 | clear_masks,
149 | [canvas_m, selected_masks],
150 | [input_image_m]
151 | ).then(
152 | preview_out_image_m,
153 | [canvas_m, selected_masks],
154 | [output_image_m, src_points_m, trg_points_m]
155 | )
156 | run_button_m.click(
157 | preview_out_image_m,
158 | [canvas_m, selected_masks],
159 | [output_image_m, src_points_m, trg_points_m]
160 | ).then(
161 | run_process,
162 | [canvas_m, input_image_m, output_image_m, src_points_m, trg_points_m, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed],
163 | [output_image_m]
164 | )
165 |
166 | clear_all_button.click(
167 | clear_all,
168 | [gr_length],
169 | [canvas, input_image, output_image, selected_points, selected_shapes, prompt, data_path, steps, noise_scale, src_points, trg_points]
170 | )
171 | canvas.clear(
172 | clear_all,
173 | [gr_length],
174 | [canvas, input_image, output_image, selected_points, selected_shapes, prompt, data_path, steps, noise_scale, src_points, trg_points]
175 | )
176 | resize_button.click(
177 | clear_points,
178 | [canvas, selected_points, selected_shapes],
179 | [input_image]
180 | ).then(
181 | resize_image,
182 | [canvas, gr_length, gr_gen_size],
183 | [canvas, input_image, output_image]
184 | )
185 | canvas.edit(
186 | draw_input_image,
187 | [canvas, selected_points, selected_shapes],
188 | input_image
189 | ).then(
190 | preview_out_image,
191 | [canvas, selected_points, selected_shapes],
192 | [output_image, src_points, trg_points]
193 | )
194 | shape.change(
195 | update_shape,
196 | [canvas, shape, selected_points, selected_shapes],
197 | [input_image]
198 | ).then(
199 | preview_out_image,
200 | [canvas, selected_points, selected_shapes],
201 | [output_image, src_points, trg_points]
202 | )
203 | input_image.upload(
204 | wrong_upload,
205 | outputs=[input_image]
206 | )
207 | input_image.select(
208 | add_point,
209 | [canvas, shape, selected_points, selected_shapes],
210 | [input_image]
211 | ).then(
212 | preview_out_image,
213 | [canvas, selected_points, selected_shapes],
214 | [output_image, src_points, trg_points]
215 | )
216 | undo_point_button.click(
217 | undo_point,
218 | [canvas, shape, selected_points, selected_shapes],
219 | [input_image]
220 | ).then(
221 | preview_out_image,
222 | [canvas, selected_points, selected_shapes],
223 | [output_image, src_points, trg_points]
224 | )
225 | clear_point_button.click(
226 | clear_points,
227 | [canvas, selected_points, selected_shapes],
228 | [input_image]
229 | ).then(
230 | preview_out_image,
231 | [canvas, selected_points, selected_shapes],
232 | [output_image, src_points, trg_points]
233 | )
234 | run_button.click(
235 | preview_out_image,
236 | [canvas, selected_points, selected_shapes],
237 | [output_image, src_points, trg_points]
238 | ).then(
239 | run_process,
240 | [canvas, input_image, output_image, src_points, trg_points, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed],
241 | [output_image]
242 | )
243 |
244 | clear_all_button_r.click(
245 | clear_all_m,
246 | [gr_length],
247 | [canvas_r, input_image_r, output_image_r, selected_points_r, prompt, data_path, steps, noise_scale, src_points_r, trg_points_r]
248 | )
249 | canvas_r.clear(
250 | clear_all_m,
251 | [gr_length],
252 | [canvas_r, input_image_r, output_image_r, selected_points_r, prompt, data_path, steps, noise_scale, src_points_r, trg_points_r]
253 | )
254 | resize_button_r.click(
255 | clear_points_r,
256 | [canvas_r, selected_points_r],
257 | [input_image_r]
258 | ).then(
259 | resize_image,
260 | [canvas_r, gr_length, gr_gen_size],
261 | [canvas_r, input_image_r, output_image_r]
262 | )
263 | canvas_r.edit(
264 | draw_input_image_r,
265 | [canvas_r, selected_points_r],
266 | [input_image_r]
267 | ).then(
268 | preview_out_image_r,
269 | [canvas_r, selected_points_r],
270 | [output_image_r, src_points_r, trg_points_r]
271 | )
272 | input_image_r.upload(
273 | wrong_upload,
274 | outputs=[input_image_r]
275 | )
276 | input_image_r.select(
277 | add_point_r,
278 | [canvas_r, selected_points_r],
279 | [input_image_r]
280 | ).then(
281 | preview_out_image_r,
282 | [canvas_r, selected_points_r],
283 | [output_image_r, src_points_r, trg_points_r]
284 | )
285 | undo_point_button_r.click(
286 | undo_point_r,
287 | [canvas_r, selected_points_r],
288 | [input_image_r]
289 | ).then(
290 | preview_out_image_r,
291 | [canvas_r, selected_points_r],
292 | [output_image_r, src_points_r, trg_points_r]
293 | )
294 | clear_point_button_r.click(
295 | clear_points_r,
296 | [canvas_r, selected_points_r],
297 | [input_image_r]
298 | ).then(
299 | preview_out_image_r,
300 | [canvas_r, selected_points_r],
301 | [output_image_r, src_points_r, trg_points_r]
302 | )
303 | run_button_r.click(
304 | preview_out_image_r,
305 | [canvas_r, selected_points_r],
306 | [output_image_r, src_points_r, trg_points_r]
307 | ).then(
308 | run_process,
309 | [canvas_r, input_image_r, output_image_r, src_points_r, trg_points_r, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed],
310 | [output_image_r]
311 | )
312 |
313 | demo.queue().launch(share=True, debug=True)
314 |
315 | if __name__ == '__main__':
316 | main()
--------------------------------------------------------------------------------