├── README.md
├── UI_GUIDE.md
├── assets
│   ├── icon.svg
│   ├── pipe.png
│   ├── polygons.gif
│   ├── region_points.gif
│   ├── regions.gif
│   ├── time.png
│   └── ui_overview.png
├── region_utils
│   ├── __init__.py
│   ├── cycle_sde.py
│   ├── drag.py
│   ├── evaluator.py
│   └── ui_utils.py
├── requirements.txt
├── run_eval.py
└── ui.py


/README.md:
--------------------------------------------------------------------------------
# RegionDrag: Fast Region-Based Image Editing with Diffusion Models (ECCV 2024)
**Jingyi Lu<sup>1</sup>, [Xinghui Li](https://xinghui-li.github.io/)<sup>2</sup>, [Kai Han](https://www.kaihan.org/)<sup>1</sup>**
[<sup>1</sup>Visual AI Lab, The University of Hong Kong](https://visailab.github.io/index.html); [<sup>2</sup>Active Vision Lab, University of Oxford](https://www.robots.ox.ac.uk/ActiveVision/)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing)

[Figures: Time (assets/time.png) and Pipe (assets/pipe.png)]

RegionDrag uses pairs of **regions** instead of **points** (e.g. DragGAN, DragDiffusion) to drag image contents. Visit our [project page](https://visual-ai.github.io/regiondrag) for various region input examples.
- **A region is equivalent to a large number of points**, providing richer input context and improving image consistency.
- By using regions as input, we can reduce run time to **1.5 seconds** (close to the time of a 20-step SD image generation).

During inference, the SD latent representations of the input image are extracted from 🔴 **RED** regions during inversion and mapped to 🔵 **BLUE** regions during denoising across multiple timesteps.

## Installation
CUDA support is required to run our code; you can try our [Colab Demo](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing) for easy access to GPU resources.
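
Since a CUDA-capable GPU is required, it can help to check the machine before installing anything. The snippet below is a minimal sketch and not part of the repository: `cuda_probably_available` is a hypothetical helper name. It asks PyTorch if it is already installed, and otherwise only checks whether NVIDIA's `nvidia-smi` tool is on the `PATH`, which suggests (but does not guarantee) a working driver.

```python
import importlib.util
import shutil


def cuda_probably_available() -> bool:
    """Best-effort check for a usable CUDA setup (hypothetical helper)."""
    if importlib.util.find_spec("torch") is not None:
        # Imported lazily so the script also runs before PyTorch is installed.
        import torch
        return bool(torch.cuda.is_available())
    # Fallback: look for NVIDIA's driver command-line tool on PATH.
    return shutil.which("nvidia-smi") is not None


print("CUDA likely available:", cuda_probably_available())
```

If this prints `False`, the Colab demo is probably the easier route.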
To install RegionDrag locally, run the following in a terminal:
```
git clone https://github.com/LuJingyi-John/RegionDrag.git
cd RegionDrag
pip install -r requirements.txt

# InstantDrag backbone is now supported (https://github.com/SNU-VGILab/InstantDrag)
git clone https://github.com/SNU-VGILab/InstantDrag instantdrag
```

## Run RegionDrag
After installing the requirements, you can launch the user interface with:
```
python3 ui.py
```
For detailed instructions on using our UI, check out our [User Guide](./UI_GUIDE.md).

## DragBench-SR & DragBench-DR
To evaluate region-based editing, we introduce [DragBench-SR](https://github.com/ML-GSAI/SDE-Drag) and [DragBench-DR](https://github.com/Yujun-Shi/DragDiffusion/) (R is short for 'Region'), which are modified versions of DragBench-S (100 samples) and DragBench-D (205 samples). These benchmarks are consistent with their point-based counterparts but use regions instead of points to reflect user intention. You can download the dataset [HERE](https://drive.google.com/file/d/1rdi4Rqka8zqHTbPyhQYtFC2UdWvAeAGV/view?usp=sharing).


```
drag_data/
├── dragbench-dr/
│   ├── animals/
│   │   ├── JH_2023-09-14-1820-16/
│   │   │   ├── original_image.png
│   │   │   ├── user_drag.png
│   │   │   ├── meta_data.pkl
│   │   │   └── meta_data_region.pkl
│   │   └── ...
│   └── ...
└── dragbench-sr/
    ├── art_0/
    │   ├── original_image.png
    │   ├── user_drag.png
    │   ├── meta_data.pkl
    │   └── meta_data_region.pkl
    └── ...
```
`meta_data.pkl` and `meta_data_region.pkl` store user interaction metadata as a dictionary:

```
{
    'prompt': text_prompt describing output image,
    'points': list of points [(x1, y1), (x2, y2), ..., (xn, yn)],
              handle points: (x1,y1), (x3,y3), ..., target points: (x2,y2), (x4,y4), ...,
    'mask': a binary mask specifying editing area,
}
```

## BibTeX
```
@inproceedings{lu2024regiondrag,
  author = {Jingyi Lu and Xinghui Li and Kai Han},
  title = {RegionDrag: Fast Region-Based Image Editing with Diffusion Models},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year = {2024},
}
```

## Related links
* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
* [DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing](https://github.com/Yujun-Shi/DragDiffusion/)
* [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
* [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
* [The Blessing of Randomness: SDE Beats ODE in General Diffusion-based Image Editing](https://github.com/ML-GSAI/SDE-Drag)
* [InstantDrag: Improving Interactivity in Drag-based Image Editing](https://github.com/SNU-VGILab/InstantDrag)

## Acknowledgement
Insightful discussions with Cheng Silin and Huang Xiaohu were instrumental in refining our methodology. The intuitive layout of the [DragDiffusion](https://github.com/Yujun-Shi/DragDiffusion/) project inspired our user interface design. Our SDE scheduler implementation builds upon the groundbreaking work by Shen Nie et al. in their [SDE-Drag](https://github.com/ML-GSAI/SDE-Drag) project.
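
As a small illustration of the `points` layout described under DragBench-SR & DragBench-DR, the sketch below splits an interleaved list into handle and target points. The record is a hypothetical stand-in built in memory (real files would be loaded with `pickle` from `meta_data.pkl` or `meta_data_region.pkl`, and the exact type of their `mask` entry is not stated here), and `split_points` is an illustrative helper name, not a function from this repository.

```python
import pickle


def split_points(points):
    """Split [handle, target, handle, target, ...] into two lists."""
    handles = points[0::2]  # (x1, y1), (x3, y3), ...
    targets = points[1::2]  # (x2, y2), (x4, y4), ...
    if len(handles) != len(targets):
        raise ValueError("expected an even number of points")
    return handles, targets


# Hypothetical record that mirrors the documented dictionary layout.
example = {
    "prompt": "a photo of a cat",
    "points": [(10, 20), (30, 40), (50, 60), (70, 80)],
    "mask": [[0, 1], [1, 1]],
}

# Round-trip through pickle in memory instead of reading a real file.
restored = pickle.loads(pickle.dumps(example))
handles, targets = split_points(restored["points"])
print(handles)  # [(10, 20), (50, 60)]
print(targets)  # [(30, 40), (70, 80)]
```

Only unpickle files from sources you trust, since `pickle` can execute arbitrary code while loading.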


--------------------------------------------------------------------------------
/UI_GUIDE.md:
--------------------------------------------------------------------------------
# RegionDrag: Fast Region-Based Image Editing with Diffusion Models (ECCV 2024)
**Jingyi Lu†, [Xinghui Li‡](https://xinghui-li.github.io/), [Kai Han†](https://www.kaihan.org/)**
[Visual AI Lab, The University of Hong Kong†](https://visailab.github.io/index.html); [Active Vision Lab, University of Oxford‡](https://www.robots.ox.ac.uk/ActiveVision/)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pnq9t_1zZ8yL_Oba20eBLVZLp3glniBR?usp=sharing)

## Overview
RegionDrag supports a variety of inputs. You can input regions or points to drag image contents from 🔴 **RED** to 🔵 **BLUE**. Below is an overview of the different components in our UI. For detailed installation instructions, check out our [README](./README.md).

[Figure: UI overview]

## Tips
- Increasing the `Handle Noise Scale` can remove handle content. If that does not work, you can drag 🔴 some other content to cover 🔵 the content you would like to remove.
- The image displayed in the `Results` column is a preview obtained from your inputs before using `Run Drag`. A better preview generally implies a better editing result.
- If you find the preview image satisfactory, you can try changing the `Method` from `Encode then CP` to `CP then Encode`.
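
To give a rough feel for the first tip, the sketch below reproduces the blending weights that `blur_source` in `region_utils/drag.py` applies to handle latents. It assumes the `Handle Noise Scale` slider is what reaches that function as `noise_scale`, which is not confirmed here because `ui.py` is not shown. `handle_mix_weights` is a hypothetical helper name.

```python
import math


def handle_mix_weights(noise_scale):
    """Return (img_scale, noise_scale) as computed in blur_source.

    The handle latent becomes: latent * img_scale + gaussian_noise * noise_scale.
    """
    img_scale = math.sqrt(1 - noise_scale ** 2) if noise_scale < 1 else 0.0
    return img_scale, noise_scale


for scale in (0.0, 0.6, 1.0):
    img_scale, noise = handle_mix_weights(scale)
    print(f"noise_scale={scale}: original weight {img_scale:.2f}, noise weight {noise:.2f}")
```

At a scale of 1 or more the original handle content gets zero weight, so the handle region is replaced by noise before denoising.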

## Input pairs of regions
- **Step 1:** Upload an image on the left, and click `Fit Canvas` to adjust the image size
- **Step 2:** Add regions (draw a mask on the left, and then click `Add Region`)
- **Step 3:** Click `Run Drag`

## Input pairs of polygons
- **Step 1:** Upload an image on the left, and click `Fit Canvas` to adjust the image size
- **Step 2:** Click points on the middle image (you can choose to input triangles or quadrilaterals)
- **Step 3:** Click `Run Drag`

## Input regions and manipulate them by points
- **Step 1:** Upload an image on the left, and click `Fit Canvas` to adjust the image size
- **Step 2:** Draw masks on the left to represent regions
- **Step 3:** Click points in the middle to control these regions
- **Step 4:** Click `Run Drag`


--------------------------------------------------------------------------------
/assets/icon.svg:
--------------------------------------------------------------------------------
[SVG markup not preserved]


--------------------------------------------------------------------------------
/assets/pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/pipe.png


--------------------------------------------------------------------------------
/assets/polygons.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/polygons.gif


--------------------------------------------------------------------------------
/assets/region_points.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/region_points.gif


--------------------------------------------------------------------------------
/assets/regions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/regions.gif


--------------------------------------------------------------------------------
/assets/time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/time.png


--------------------------------------------------------------------------------
/assets/ui_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/assets/ui_overview.png -------------------------------------------------------------------------------- /region_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/RegionDrag/5cf492212c24edae5ca2376579f2e81f46fc322f/region_utils/__init__.py -------------------------------------------------------------------------------- /region_utils/cycle_sde.py: -------------------------------------------------------------------------------- 1 | # This file is developed based on the work of [ML-GSAI/SDE-Drag], 2 | # which can be found at https://github.com/ML-GSAI/SDE-Drag 3 | 4 | import torch 5 | import random 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms 9 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler 10 | 11 | def load_model(version="v1-5", torch_device='cuda', torch_dtype=torch.float16, verbose=True): 12 | pipe_paths = { 13 | 'v1-5' : "runwayml/stable-diffusion-v1-5", 14 | 'v2-1' : "stabilityai/stable-diffusion-2-1", 15 | 'xl' : "stabilityai/stable-diffusion-xl-base-1.0" 16 | } 17 | pipe_path = pipe_paths.get(version, pipe_paths['v1-5']) 18 | pipe = StableDiffusionPipeline if version in ['v1-5', 'v2-1'] else StableDiffusionXLPipeline 19 | 20 | if verbose: 21 | print(f'Loading model from {pipe_path}.') 22 | pipe = pipe.from_pretrained(pipe_path, torch_dtype=torch_dtype).to(torch_device) 23 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 24 | 25 | # IP-Adaptor 26 | if version in ['v1-5', 'v2-1']: 27 | subfolder, weight_name, ip_adapter_scale = "models", "ip-adapter-plus_sd15.bin", 0.5 28 | else: 29 | subfolder, weight_name, ip_adapter_scale = "sdxl_models", "ip-adapter_sdxl.bin", 0.6 30 | pipe.load_ip_adapter("h94/IP-Adapter", subfolder=subfolder, 
weight_name=weight_name) 31 | pipe.set_ip_adapter_scale(ip_adapter_scale) 32 | 33 | tokenizer_2 = pipe.tokenizer_2 if version == 'xl' else None 34 | text_encoder_2 = pipe.text_encoder_2 if version == 'xl' else None 35 | return pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, pipe.scheduler, pipe.feature_extractor, pipe.image_encoder, tokenizer_2, text_encoder_2 36 | 37 | @torch.no_grad() 38 | def get_text_embed(prompt: list, tokenizer, text_encoder, tokenizer_2=None, text_encoder_2=None, torch_device='cuda'): 39 | text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 40 | prompt_embeds = text_encoder(text_input.input_ids.to(torch_device), output_hidden_states=True) 41 | pooled_prompt_embeds, prompt_embeds = prompt_embeds[0], prompt_embeds.hidden_states[-2] 42 | 43 | if tokenizer_2 is not None and text_encoder_2 is not None: 44 | text_input = tokenizer_2(prompt, padding="max_length", max_length=tokenizer_2.model_max_length, truncation=True, return_tensors="pt") 45 | prompt_embeds_2 = text_encoder_2(text_input.input_ids.to(torch_device), output_hidden_states=True) 46 | pooled_prompt_embeds, prompt_embeds_2 = prompt_embeds_2[0], prompt_embeds_2.hidden_states[-2] 47 | prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_2], dim=-1) 48 | 49 | return pooled_prompt_embeds, prompt_embeds 50 | 51 | @torch.no_grad() 52 | def get_img_latent(image, vae, torch_device='cuda', dtype=torch.float16, size=None): 53 | # upcast vae for sdxl, attention blocks can be in torch.float16 54 | upcast_dtype = torch.float32 if 'xl-base-1.0' in vae.config._name_or_path and dtype == torch.float16 else dtype 55 | if dtype == torch.float16: 56 | vae = vae.to(upcast_dtype) 57 | for module in [vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block]: 58 | module = module.to(dtype) 59 | 60 | image = Image.open(image).convert('RGB') if isinstance(image, str) else Image.fromarray(image) 61 | image = 
image.resize(size) if size else image 62 | image = transforms.ToTensor()(image).unsqueeze(0).to(torch_device, upcast_dtype) 63 | latents = vae.encode(image * 2 - 1).latent_dist.sample() * 0.18215 64 | return latents.to(dtype) 65 | 66 | def set_seed(seed): 67 | torch.manual_seed(seed) 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | 71 | torch.backends.cudnn.deterministic = True 72 | torch.backends.cudnn.benchmark = False 73 | 74 | class Sampler(): 75 | def __init__(self, unet, scheduler, num_steps=100): 76 | scheduler.set_timesteps(num_steps) 77 | self.num_inference_steps = num_steps 78 | self.num_train_timesteps = len(scheduler) 79 | 80 | self.alphas = scheduler.alphas 81 | self.alphas_cumprod = scheduler.alphas_cumprod 82 | 83 | self.final_alpha_cumprod = torch.tensor(1.0) 84 | self.initial_alpha_cumprod = torch.tensor(1.0) 85 | 86 | self.unet = unet 87 | 88 | def get_eps(self, img, timestep, guidance_scale, text_embeddings, lora_scale=None, added_cond_kwargs=None): 89 | guidance_scale = max(1, guidance_scale) 90 | 91 | text_embeddings = text_embeddings if guidance_scale > 1. else text_embeddings[-1:] 92 | latent_model_input = torch.cat([img] * 2) if guidance_scale > 1. else img 93 | cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale} 94 | 95 | if guidance_scale == 1. and added_cond_kwargs is not None: 96 | added_cond_kwargs = {k: v[-1:] for k, v in added_cond_kwargs.items()} 97 | 98 | noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs).sample 99 | 100 | if guidance_scale > 1.: 101 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 102 | elif guidance_scale == 1.: 103 | noise_pred_text = noise_pred 104 | noise_pred_uncond = 0. 
105 | else: 106 | raise NotImplementedError(guidance_scale) 107 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 108 | 109 | return noise_pred 110 | 111 | def sample(self, timestep, sample, guidance_scale, text_embeddings, sde=False, noise=None, eta=1., lora_scale=None, added_cond_kwargs=None): 112 | eps = self.get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale, added_cond_kwargs) 113 | 114 | prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps 115 | 116 | alpha_prod_t = self.alphas_cumprod[timestep] 117 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod 118 | 119 | beta_prod_t = 1 - alpha_prod_t 120 | 121 | sigma_t = eta * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5) * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5) if sde else 0 122 | 123 | pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5) 124 | pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t ** 2) ** (0.5) 125 | 126 | noise = torch.randn_like(sample, device=sample.device) if noise is None else noise 127 | img = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise 128 | 129 | return img 130 | 131 | def forward_sde(self, timestep, sample, guidance_scale, text_embeddings, eta=1., lora_scale=None, added_cond_kwargs=None): 132 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps 133 | 134 | alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod 135 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] 136 | 137 | beta_prod_t_prev = 1 - alpha_prod_t_prev 138 | 139 | x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (0.5) * torch.randn_like(sample, device=sample.device) 140 | eps = self.get_eps(x_prev, prev_timestep, guidance_scale, 
text_embeddings, lora_scale, added_cond_kwargs) 141 | 142 | sigma_t_prev = eta * ((1 - alpha_prod_t) / (1 - alpha_prod_t_prev)) ** (0.5) * (1 - alpha_prod_t_prev / alpha_prod_t) ** (0.5) 143 | 144 | pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5) 145 | pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev ** 2) ** (0.5) 146 | 147 | noise = (sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps) / sigma_t_prev 148 | 149 | return x_prev, noise 150 | 151 | def forward_ode(self, timestep, sample, guidance_scale, text_embeddings, lora_scale=None, added_cond_kwargs=None): 152 | prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps 153 | 154 | alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod 155 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] 156 | 157 | beta_prod_t = 1 - alpha_prod_t 158 | 159 | eps = self.get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale, added_cond_kwargs) 160 | pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5) 161 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * eps 162 | 163 | img = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 164 | 165 | noise = None 166 | return img, noise -------------------------------------------------------------------------------- /region_utils/drag.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import numpy as np 4 | import pickle 5 | from PIL import Image 6 | import cv2 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .cycle_sde import Sampler, get_img_latent, get_text_embed, load_model, set_seed 13 | 14 | # select version from v1-5 (recommended), v2-1, xl 15 | sd_version = 'v1-5' 16 | 17 | # --- To include: InstantDrag 
(https://github.com/SNU-VGILab/InstantDrag) --- # 18 | import sys 19 | sys.path.append('instantdrag/') 20 | if not os.path.exists('instantdrag/utils/__init__.py'): 21 | open('instantdrag/utils/__init__.py', 'a').close() 22 | 23 | from huggingface_hub import snapshot_download 24 | from instantdrag.demo.demo_utils import InstantDragPipeline 25 | os.makedirs("./checkpoints", exist_ok=True) 26 | snapshot_download("alex4727/InstantDrag", local_dir="./checkpoints") 27 | 28 | def scale_schedule(begin, end, n, length, type='linear'): 29 | if type == 'constant': 30 | return end 31 | elif type == 'linear': 32 | return begin + (end - begin) * n / length 33 | elif type == 'cos': 34 | factor = (1 - math.cos(n * math.pi / length)) / 2 35 | return (1 - factor) * begin + factor * end 36 | else: 37 | raise NotImplementedError(type) 38 | 39 | def get_meta_data(meta_data_path): 40 | with open(meta_data_path, 'rb') as file: 41 | meta_data = pickle.load(file) 42 | prompt = meta_data['prompt'] 43 | mask = meta_data['mask'] 44 | points = meta_data['points'] 45 | source = points[0:-1:2] 46 | target = points[1::2] 47 | return prompt, mask, source, target 48 | 49 | def get_drag_data(data_path): 50 | ori_image_path = os.path.join(data_path, 'original_image.png') 51 | meta_data_path = os.path.join(data_path, 'meta_data_region.pkl') 52 | 53 | original_image = np.array(Image.open(ori_image_path)) 54 | prompt, mask, source, target = get_meta_data(meta_data_path) 55 | 56 | return { 57 | 'ori_image' : original_image, 'preview' : original_image, 'prompt' : prompt, 58 | 'mask' : mask, 'source' : np.array(source), 'target' : np.array(target) 59 | } 60 | 61 | def reverse_and_repeat_every_n_elements(lst, n, repeat=1): 62 | """ 63 | Reverse every n elements in a given list, then repeat the reversed segments 64 | the specified number of times. 
65 | Example: 66 | >>> reverse_and_repeat_every_n_elements([1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 2) 67 | [3, 2, 1, 3, 2, 1, 6, 5, 4, 6, 5, 4, 9, 8, 7, 9, 8, 7] 68 | """ 69 | if not lst or n < 1: 70 | return lst 71 | return [element for i in range(0, len(lst), n) for _ in range(repeat) for element in reversed(lst[i:i+n])] 72 | 73 | def get_border_points(points): 74 | x_max, y_max = np.amax(points, axis=0) 75 | mask = np.zeros((y_max+1, x_max+1), np.uint8) 76 | mask[points[:, 1], points[:, 0]] = 1 77 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 78 | border_points = np.concatenate([contour[:, 0, :] for contour in contours], axis=0) 79 | return border_points 80 | 81 | def postprocess(vae, latent, ori_image, mask): 82 | dtype = latent.dtype 83 | upcast_dtype = torch.float32 if 'xl-base-1.0' in vae.config._name_or_path and dtype == torch.float16 else dtype 84 | H, W = ori_image.shape[:2] 85 | 86 | if dtype == torch.float16: 87 | vae = vae.to(upcast_dtype) 88 | for module in [vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block]: 89 | module = module.to(dtype) 90 | 91 | image = vae.decode(latent / 0.18215).sample / 2 + 0.5 92 | image = (image.clamp(0, 1).permute(0, 2, 3, 1)[0].cpu().numpy() * 255).astype(np.uint8) 93 | image = cv2.resize(image, (W, H)) 94 | 95 | if not np.all(mask == 1): 96 | image = np.where(mask[:, :, None], image, ori_image) 97 | 98 | return image 99 | 100 | def copy_and_paste(source_latents, target_latents, source, target): 101 | target_latents[0, :, target[:, 1], target[:, 0]] = source_latents[0, :, source[:, 1], source[:, 0]] 102 | return target_latents 103 | 104 | def blur_source(latents, noise_scale, source): 105 | img_scale = (1 - noise_scale ** 2) ** (0.5) if noise_scale < 1 else 0 106 | latents[0, :, source[:, 1], source[:, 0]] = latents[0, :, source[:, 1], source[:, 0]] * img_scale + \ 107 | torch.randn_like(latents[0, :, source[:, 1], source[:, 0]]) * noise_scale 108 | return latents 109 | 110 | def 
ip_encode_image(feature_extractor, image_encoder, image): 111 | dtype = next(image_encoder.parameters()).dtype 112 | device = next(image_encoder.parameters()).device 113 | 114 | image = feature_extractor(image, return_tensors="pt").pixel_values.to(device=device, dtype=dtype) 115 | image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2] 116 | uncond_image_enc_hidden_states = image_encoder( 117 | torch.zeros_like(image), output_hidden_states=True 118 | ).hidden_states[-2] 119 | image_embeds = torch.stack([uncond_image_enc_hidden_states, image_enc_hidden_states], dim=0) 120 | 121 | return [image_embeds] 122 | 123 | def forward(scheduler, sampler, steps, start_t, latent, text_embeddings, added_cond_kwargs, progress=tqdm, sde=True): 124 | forward_func = sampler.forward_sde if sde else sampler.forward_ode 125 | hook_latents = [latent,]; noises = []; cfg_scales = [] 126 | start_t = int(start_t * steps) 127 | 128 | for index, t in enumerate(progress(scheduler.timesteps[(steps - start_t):].flip(dims=[0])), start=1): 129 | cfg_scale = scale_schedule(1, 1, index, steps, type='linear') 130 | latent, noise = forward_func(t, latent, cfg_scale, text_embeddings, added_cond_kwargs=added_cond_kwargs) 131 | hook_latents.append(latent); noises.append(noise); cfg_scales.append(cfg_scale) 132 | 133 | return hook_latents, noises, cfg_scales 134 | 135 | def backward(scheduler, sampler, steps, start_t, end_t, noise_scale, hook_latents, noises, cfg_scales, mask, text_embeddings, added_cond_kwargs, blur, source, target, progress=tqdm, latent=None, sde=True): 136 | start_t = int(start_t * steps) 137 | end_t = int(end_t * steps) 138 | 139 | latent = hook_latents[-1].clone() if latent is None else latent 140 | latent = blur_source(latent, noise_scale, blur) 141 | 142 | for t in progress(scheduler.timesteps[(steps-start_t- 1):-1]): 143 | hook_latent = hook_latents.pop() 144 | latent = copy_and_paste(hook_latent, latent, source, target) if t >= end_t else 
latent 145 | latent = torch.where(mask == 1, latent, hook_latent) 146 | latent = sampler.sample(t, latent, cfg_scales.pop(), text_embeddings, sde=sde, noise=noises.pop(), added_cond_kwargs=added_cond_kwargs) 147 | return latent 148 | 149 | def drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress=tqdm, method='Encode then CP', save_path=''): 150 | set_seed(seed) 151 | device = 'cuda' 152 | ori_image, preview, prompt, mask, source, target = drag_data.values() 153 | 154 | if method in ('Encode then CP', 'CP then Encode'): 155 | global vae, tokenizer, text_encoder, unet, scheduler, feature_extractor, image_encoder, tokenizer_2, text_encoder_2 156 | if 'vae' not in globals(): 157 | vae, tokenizer, text_encoder, unet, scheduler, feature_extractor, image_encoder, tokenizer_2, text_encoder_2 = load_model(sd_version) 158 | 159 | def copy_key_hook(module, input, output): 160 | keys.append(output) 161 | def copy_value_hook(module, input, output): 162 | values.append(output) 163 | def paste_key_hook(module, input, output): 164 | output[:] = keys.pop() 165 | def paste_value_hook(module, input, output): 166 | output[:] = values.pop() 167 | 168 | def register(do='copy'): 169 | do_copy = do == 'copy' 170 | key_hook, value_hook = (copy_key_hook, copy_value_hook) if do_copy else (paste_key_hook, paste_value_hook) 171 | key_handlers = []; value_handlers = [] 172 | for block in (*sampler.unet.down_blocks, sampler.unet.mid_block, *sampler.unet.up_blocks): 173 | if not hasattr(block, 'attentions'): 174 | continue 175 | for attention in block.attentions: 176 | for tb in attention.transformer_blocks: 177 | key_handlers.append(tb.attn1.to_k.register_forward_hook(key_hook)) 178 | value_handlers.append(tb.attn1.to_v.register_forward_hook(value_hook)) 179 | return key_handlers, value_handlers 180 | 181 | def unregister(*handlers): 182 | for handler in handlers: 183 | handler.remove() 184 | torch.cuda.empty_cache() 185 | 186 | sde = encode_then_cp = method == 'Encode then CP' 187 
| source = torch.from_numpy(source).to(device) if isinstance(source, np.ndarray) else source.to(device) 188 | target = torch.from_numpy(target).to(device) if isinstance(target, np.ndarray) else target.to(device) 189 | source = source // 8; target = target // 8 # from img scale to latent scale 190 | 191 | if encode_then_cp: 192 | blur_pts = source; copy_pts = source 193 | else: 194 | blur_pts = torch.cat([torch.from_numpy(get_border_points(target.cpu().numpy())).to(device), source], dim=0) 195 | copy_pts = target 196 | paste_pts = target 197 | 198 | latent = get_img_latent(ori_image, vae) 199 | preview_latent = get_img_latent(preview, vae) if not encode_then_cp else None 200 | sampler = Sampler(unet=unet, scheduler=scheduler, num_steps=steps) 201 | 202 | with torch.no_grad(): 203 | neg_pooled_prompt_embeds, neg_prompt_embeds = get_text_embed("", tokenizer, text_encoder, tokenizer_2, text_encoder_2) 204 | neg_prompt_embeds = neg_prompt_embeds if sd_version == 'xl' else neg_pooled_prompt_embeds 205 | pooled_prompt_embeds, prompt_embeds = get_text_embed(prompt, tokenizer, text_encoder, tokenizer_2, text_encoder_2) 206 | prompt_embeds = prompt_embeds if sd_version == 'xl' else pooled_prompt_embeds 207 | prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) 208 | pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 209 | 210 | image_embeds = ip_encode_image(feature_extractor, image_encoder, ori_image) 211 | 212 | H, W = ori_image.shape[:2] 213 | add_time_ids = torch.tensor([[H, W, 0, 0, H, W]]).to(prompt_embeds).repeat(2, 1) 214 | added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} if sd_version == 'xl' else {} 215 | added_cond_kwargs["image_embeds"] = image_embeds 216 | 217 | mask_pt = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device) 218 | mask_pt = F.interpolate(mask_pt, size=latent.shape[2:]).expand_as(latent) 219 | 220 | if not encode_then_cp: 221 | hook_latents, noises, 
cfg_scales = forward(scheduler, sampler, steps, start_t, preview_latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde) 222 | 223 | keys = []; values = [] 224 | key_handlers, value_handlers = register(do='copy') 225 | if encode_then_cp: 226 | hook_latents, noises, cfg_scales = forward(scheduler, sampler, steps, start_t, latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde) 227 | start_latent = None 228 | else: 229 | start_latent = forward(scheduler, sampler, steps, start_t, latent, prompt_embeds, added_cond_kwargs, progress=progress, sde=sde)[0][-1] 230 | unregister(*key_handlers, *value_handlers) 231 | 232 | keys = reverse_and_repeat_every_n_elements(keys, n=len(key_handlers)) 233 | values = reverse_and_repeat_every_n_elements(values, n=len(value_handlers)) 234 | 235 | key_handlers, value_handlers = register(do='paste') 236 | latent = backward(scheduler, sampler, steps, start_t, end_t, noise_scale, hook_latents, noises, cfg_scales, mask_pt, prompt_embeds, added_cond_kwargs, blur_pts, copy_pts, paste_pts, latent=start_latent, progress=progress, sde=sde) 237 | unregister(*key_handlers, *value_handlers) 238 | 239 | image = postprocess(vae, latent, ori_image, mask) 240 | 241 | elif method == 'InstantDrag': 242 | global instant_pipe 243 | if 'instant_pipe' not in globals(): 244 | instant_pipe = InstantDragPipeline(seed, 'cuda', torch.float16) 245 | flowgen_ckpt = next((m for m in sorted(os.listdir("checkpoints/")) if "flowgen" in m), None) 246 | flowdiffusion_ckpt = next(f for f in sorted(os.listdir("checkpoints/")) if "flowdiffusion" in f) 247 | 248 | print('Unused parameters in utils.drag function: input, start_t, end_t, noise_scale, progress') 249 | selected_points = [point.tolist() for pair in zip(source, target) for point in pair] 250 | ori_H, ori_W = ori_image.shape[:2]; new_H, new_W = 512, 512 251 | selected_points = torch.tensor(selected_points) / torch.tensor([ori_W, ori_H]) * torch.tensor([new_W, new_H]) 252 | image_guidance, 
flow_guidance, flowgen_output_scale = 1.5, 1.5, -1.0 253 | image = instant_pipe.run(cv2.resize(ori_image, (new_W, new_H)), selected_points.tolist(), flowgen_ckpt, flowdiffusion_ckpt, image_guidance, 254 | flow_guidance, flowgen_output_scale, steps, save_results=False) 255 | image = cv2.resize(image, (ori_W, ori_H)) 256 | 257 | else: 258 | raise ValueError('Select method from InstantDrag, Encode then CP, CP then Encode') 259 | 260 | if save_path: 261 | os.makedirs(save_path, exist_ok=True) 262 | counter = 0 263 | file_root, file_extension = os.path.splitext('dragged_image.png') 264 | while True: 265 | test_name = f"{file_root} ({counter}){file_extension}" if counter != 0 else 'dragged_image.png' 266 | full_path = os.path.join(save_path, test_name) 267 | if not os.path.exists(full_path): 268 | Image.fromarray(image).save(full_path) 269 | break 270 | counter += 1 271 | torch.cuda.empty_cache() 272 | return image -------------------------------------------------------------------------------- /region_utils/evaluator.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import cv2 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import lpips 11 | from diffusers import StableDiffusionPipeline 12 | from transformers import AutoModel 13 | 14 | warnings.filterwarnings(action='ignore', category=UserWarning) 15 | 16 | def plot_matching_result(src_image, trg_image, src_points, trg_points, pred_trg_points, output_path=None, border_width=10): 17 | """ 18 | Used to visualize dragging effect. 
19 | """ 20 | if src_points.shape != trg_points.shape or src_points.shape != pred_trg_points.shape: 21 | raise ValueError(f"points arrays must have the same shapes, got Source:{src_points.shape}, Target:{trg_points.shape}, Predicted:{pred_trg_points.shape}") 22 | if src_image.shape[0] != trg_image.shape[0]: 23 | raise ValueError(f"Source image and target image must have same height, got Source:{src_image.shape}, Target:{trg_image.shape}") 24 | 25 | # Create a border and combine images 26 | border = np.ones((src_image.shape[0], border_width, 3), dtype=np.uint8) * 255 27 | combined_image = np.concatenate((src_image, border, trg_image), axis=1) 28 | 29 | # Adjust target and predicted target points by the width of the source image and the border 30 | trg_points_adj = trg_points + np.array([src_image.shape[1] + border_width, 0]) 31 | pred_trg_points_adj = pred_trg_points + np.array([src_image.shape[1] + border_width, 0]) 32 | 33 | # Create the plot 34 | fig, ax = plt.subplots(figsize=(10, 5)) 35 | ax.imshow(combined_image.astype(np.uint8)) 36 | 37 | # Draw arrows and points 38 | for src_pt, trg_pt, pred_trg_pt in zip(src_points, trg_points_adj, pred_trg_points_adj): 39 | ax.scatter(src_pt[0], src_pt[1], color='red', s=20) 40 | ax.scatter(trg_pt[0], trg_pt[1], color='blue', s=20) 41 | ax.scatter(pred_trg_pt[0], pred_trg_pt[1], color='red', s=20) 42 | 43 | # Draw arrows from src to trg 44 | ax.arrow(src_pt[0], src_pt[1], trg_pt[0] - src_pt[0], trg_pt[1] - src_pt[1], 45 | head_width=5, head_length=10, fc='white', ec='white', length_includes_head=True, lw=2, alpha=0.8) 46 | 47 | # Draw arrows from pred_trg to trg 48 | ax.arrow(pred_trg_pt[0], pred_trg_pt[1], trg_pt[0] - pred_trg_pt[0], trg_pt[1] - pred_trg_pt[1], 49 | head_width=5, head_length=10, fc='white', ec='white', length_includes_head=True, lw=2, alpha=0.8) 50 | 51 | ax.axis('off') 52 | 53 | # Save or show the image 54 | if output_path is not None: 55 | directory = os.path.dirname(output_path) 56 | if directory != 
'': 57 | os.makedirs(directory, exist_ok=True) 58 | plt.savefig(output_path, bbox_inches='tight', dpi=300) # Save with higher resolution 59 | 60 | plt.show() 61 | plt.close() 62 | 63 | return fig if output_path is None else output_path 64 | 65 | 66 | def create_mask(src_points, trg_points, img_size): 67 | """ 68 | Creates a batch of masks based on the distance of image pixels to batches of given points. 69 | 70 | Args: 71 | src_points (torch.Tensor): The source points coordinates of shape (N, 2) [x, y]. 72 | trg_points (torch.Tensor): The target points coordinates of shape (N, 2) [x, y]. 73 | img_size (tuple): The size of the image (height, width). 74 | 75 | Returns: 76 | torch.Tensor: A batch of boolean masks where True indicates the pixel is within the distance for each point pair. 77 | """ 78 | src_points = src_points.float() 79 | trg_points = trg_points.float() 80 | 81 | h, w = img_size 82 | point_distances = ((src_points - trg_points).norm(dim=1) / (2**0.5)).clamp(min=5) # Multiplying by 1/sqrt(2) 83 | 84 | y_indices, x_indices = torch.meshgrid( 85 | torch.arange(h, device=src_points.device), 86 | torch.arange(w, device=src_points.device), 87 | indexing="ij" 88 | ) 89 | 90 | # Expand grid to match the batch size (y_indices, x_indices: shape [N, H, W]) 91 | y_indices = y_indices.expand(src_points.size(0), -1, -1) 92 | x_indices = x_indices.expand(src_points.size(0), -1, -1) 93 | 94 | distance_to_p0 = ((x_indices - src_points[:, None, None, 0])**2 + (y_indices - src_points[:, None, None, 1])**2).sqrt() 95 | distance_to_p1 = ((x_indices - trg_points[:, None, None, 0])**2 + (y_indices - trg_points[:, None, None, 1])**2).sqrt() 96 | masks = (distance_to_p0 < point_distances[:, None, None]) | (distance_to_p1 < point_distances[:, None, None]) # (N, H, W) 97 | 98 | return masks 99 | 100 | def nn_get_matches(src_featmaps, trg_featmaps, query, l2_norm=True, mask=None): 101 | ''' 102 | Find the nearest neighbour matches for a given query from source feature maps in 
target feature maps. 103 | 104 | Args: 105 | src_featmaps (torch.Tensor): Source feature map with shape (1 x C x H x W). 106 | trg_featmaps (torch.Tensor): Target feature map with shape (1 x C x H x W). 107 | query (torch.Tensor): (x, y) coordinates with shape (N x 2), must be in the range of src_featmaps. 108 | l2_norm (bool): If True, apply L2 normalization to features. 109 | mask (torch.Tensor): Optional mask with shape (N x H x W). 110 | 111 | Returns: 112 | torch.Tensor: (x, y) coordinates of the top matches with shape (N x 2). 113 | ''' 114 | # Extract features from the source feature map at the query points 115 | _, c, h, w = src_featmaps.shape # (1, C, H, W) 116 | query = query.long() 117 | src_feat = src_featmaps[0, :, query[:, 1], query[:, 0]] # (C, N) 118 | 119 | if l2_norm: 120 | src_feat = F.normalize(src_feat, p=2, dim=0) 121 | trg_featmaps = F.normalize(trg_featmaps, p=2, dim=1) 122 | 123 | trg_featmaps = trg_featmaps.view(c, -1) # flatten (C, H*W) 124 | similarity = torch.mm(src_feat.t(), trg_featmaps) # similarity shape: (N, H*W) 125 | 126 | if mask is not None: 127 | mask = mask.view(-1, h * w) # mask shape: (N, H*W) 128 | similarity = torch.where(mask, similarity, torch.full_like(similarity, -torch.inf)) 129 | 130 | # Get the indices of the best matches 131 | best_match_idx = similarity.argmax(dim=-1) # best_match_idx shape: (N,) 132 | 133 | # Convert flat indices to 2D coordinates 134 | y_coords = best_match_idx // w # y_coords shape: (N,) 135 | x_coords = best_match_idx % w # x_coords shape: (N,) 136 | coords = torch.stack((x_coords, y_coords), dim=1) # coords shape: (N, 2) 137 | 138 | return coords.float() # Output shape: (N, 2) 139 | 140 | class SDFeaturizer(StableDiffusionPipeline): 141 | """Used to extract SD2-1 feature for semantic point matching (DIFT).""" 142 | @torch.no_grad() 143 | def __call__( 144 | self, 145 | img_tensor, 146 | t=261, 147 | ensemble=8, 148 | prompt=None, 149 | prompt_embeds=None 150 | ): 151 | device = 
self._execution_device 152 | latents = self.vae.encode(img_tensor).latent_dist.mode() * self.vae.config.scaling_factor 153 | latents = latents.expand(ensemble, -1, -1, -1) 154 | t = torch.tensor(t, dtype=torch.long, device=device) 155 | noise = torch.randn_like(latents) 156 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 157 | 158 | if prompt_embeds is None: 159 | if prompt is None: 160 | prompt = "" 161 | prompt_embeds = self.encode_prompt( 162 | prompt=prompt, 163 | device=device, 164 | num_images_per_prompt=1, 165 | do_classifier_free_guidance=False 166 | )[0] 167 | prompt_embeds = prompt_embeds.expand(ensemble, -1, -1) 168 | 169 | # Cache output of second upblock of unet 170 | unet_feature = [] 171 | def hook(module, input, output): 172 | unet_feature.clear() 173 | unet_feature.append(output) 174 | handle = list(self.unet.children())[4][1].register_forward_hook(hook=hook) 175 | self.unet(latents_noisy, t, prompt_embeds) 176 | handle.remove() 177 | 178 | return unet_feature[0].mean(dim=0, keepdim=True) 179 | 180 | class DragEvaluator: 181 | def __init__(self): 182 | self.clip_loaded = False 183 | self.dino_loaded = False 184 | self.sd_loaded = False 185 | self.lpips_loaded = False 186 | 187 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 188 | self.dtype = torch.float16 189 | 190 | def load_dino(self): 191 | if not self.dino_loaded: 192 | dino_path = 'facebook/dinov2-base' 193 | self.dino_model = AutoModel.from_pretrained(dino_path).to(self.device).to(self.dtype) 194 | self.dino_loaded = True 195 | 196 | def load_sd(self): 197 | if not self.sd_loaded: 198 | sd_path = 'stabilityai/stable-diffusion-2-1' 199 | self.sd_feat = SDFeaturizer.from_pretrained(sd_path, torch_dtype=self.dtype).to(self.device) 200 | self.sd_loaded = True 201 | 202 | def load_lpips(self): 203 | if not self.lpips_loaded: 204 | self.loss_fn_alex = lpips.LPIPS(net='alex').to(self.device).to(self.dtype) 205 | self.lpips_loaded = True 206 | 207 | def 
preprocess_image(self, image): 208 | image = torch.from_numpy(np.array(image)).float() / 127.5 - 1 # Normalize to [-1, 1] 209 | image = image.unsqueeze(0).permute(0, 3, 1, 2) # Rearrange to (1, C, H, W) 210 | return image.to(self.device).to(self.dtype) 211 | 212 | @torch.no_grad() 213 | def compute_lpips(self, image1, image2): 214 | """ 215 | Learned Perceptual Image Patch Similarity (LPIPS) metric. (https://richzhang.github.io/PerceptualSimilarity/) 216 | """ 217 | self.load_lpips() 218 | 219 | image1 = self.preprocess_image(image1) 220 | image2 = self.preprocess_image(image2) 221 | image1 = F.interpolate(image1, (224,224), mode='bilinear') 222 | image2 = F.interpolate(image2, (224,224), mode='bilinear') 223 | 224 | return self.loss_fn_alex(image1, image2).item() 225 | 226 | def encode_image(self, image, method, prompt=None): 227 | if method == 'dino': 228 | featmap = self.dino_model(image).last_hidden_state[:, 1:, :].permute(0, 2, 1) 229 | featmap = featmap.view(1, -1, 60, 60) # 60 = 840 / 14 230 | elif method == 'sd': 231 | featmap = self.sd_feat(image, prompt=prompt) 232 | else: 233 | raise NotImplementedError('Only SD and DINO supported.') 234 | return featmap 235 | 236 | @torch.no_grad() 237 | def compute_distance(self, src_image, trg_image, src_kps, trg_kps, method, prompt=None, plot_path=None): 238 | """ Mean Distance Metric """ 239 | if method == 'dino': 240 | self.load_dino() 241 | elif method == 'sd': 242 | self.load_sd() 243 | else: 244 | raise NotImplementedError('Only SD and DINO supported.') 245 | 246 | src_kps = torch.tensor(src_kps, device=self.device).to(torch.long) # (N, 2) N points 247 | trg_kps = torch.tensor(trg_kps, device=self.device).to(torch.long) # (N, 2) N points 248 | 249 | # Resize target image and scale target points when necessary 250 | if src_image.shape != trg_image.shape: 251 | src_img_h, src_img_w, _ = src_image.shape 252 | trg_img_h, trg_img_w, _ = trg_image.shape 253 | trg_image = cv2.resize(trg_image, (src_img_w, src_img_h)) 
254 | 255 | trg_kps = trg_kps * torch.tensor([src_img_w, src_img_h], device=self.device) 256 | trg_kps = trg_kps / torch.tensor([trg_img_w, trg_img_h], device=self.device) 257 | trg_kps = trg_kps.to(torch.long) 258 | 259 | image_h, image_w, _ = src_image.shape 260 | image_size = 840 if method == 'dino' else 768 261 | 262 | source_image = self.preprocess_image(src_image) # 1, 3, H, W 263 | target_image = self.preprocess_image(trg_image) # 1, 3, H, W 264 | source_image = F.interpolate(source_image, size=(image_size, image_size)) # 1, 3, img_size, img_size 265 | target_image = F.interpolate(target_image, size=(image_size, image_size)) # 1, 3, img_size, img_size 266 | 267 | src_featmap = self.encode_image(source_image, method=method, prompt=prompt) 268 | src_featmap = F.interpolate(src_featmap, size=(image_h, image_w)) 269 | 270 | trg_featmap = self.encode_image(target_image, method=method, prompt=prompt) 271 | trg_featmap = F.interpolate(trg_featmap, size=(image_h, image_w)) 272 | 273 | mask = create_mask(src_kps, trg_kps, (image_h, image_w)) 274 | pred_trg_kps = nn_get_matches(src_featmap, trg_featmap, src_kps, l2_norm=True, mask=mask) 275 | 276 | distance = trg_kps - pred_trg_kps 277 | distance[:, 0] /= image_w; distance[:, 1] /= image_h 278 | distance = distance.norm(dim=-1).mean().item() 279 | 280 | if plot_path: 281 | plot_matching_result(src_image, trg_image, src_kps.cpu().numpy(),trg_kps.cpu().numpy(), pred_trg_kps.cpu().numpy(), output_path=plot_path) 282 | 283 | return distance -------------------------------------------------------------------------------- /region_utils/ui_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | import gradio as gr 7 | import torch 8 | import time 9 | import functools 10 | 11 | from .drag import drag 12 | 13 | # 1. 
Common utils 14 | def timeit(func): 15 | """Decorator to measure the execution time of a function.""" 16 | @functools.wraps(func) 17 | def wrapper(*args, **kwargs): 18 | start_time = time.time() 19 | result = func(*args, **kwargs) 20 | end_time = time.time() 21 | elapsed_time = end_time - start_time 22 | print(f"Function {func.__name__!r} took {elapsed_time:.4f} seconds to complete.") 23 | return result 24 | 25 | return wrapper 26 | 27 | def get_W_H(max_length, aspect_ratio): 28 | height = int(max_length / aspect_ratio) if aspect_ratio >= 1 else max_length 29 | width = max_length if aspect_ratio >= 1 else int(max_length * aspect_ratio) 30 | height = int(height / 8) * 8 31 | width = int(width / 8) * 8 32 | return width, height 33 | 34 | def canvas_to_image_and_mask(canvas): 35 | """Extracts the image (H, W, 3) and the mask (H, W) from a Gradio canvas object.""" 36 | image = canvas["image"].copy() 37 | mask = np.uint8(canvas["mask"][:, :, 0] > 0).copy() 38 | return image, mask 39 | 40 | def draw_mask_border(image, mask): 41 | """Find the contours of shapes in the mask and draw them on the image.""" 42 | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 43 | cv2.drawContours(image, contours, -1, (255, 255, 255), 2) 44 | 45 | def mask_image(image, mask, color=[255,0,0], alpha=0.3): 46 | """Apply a binary mask to an image, highlighting the masked regions with a specified color and transparency.""" 47 | out = image.copy() 48 | out[mask == 1] = color 49 | out = cv2.addWeighted(out, alpha, image, 1-alpha, 0, out) 50 | return out 51 | 52 | def resize_image(canvas, canvas_length, image_length): 53 | """Fit the image to an appropriate size.""" 54 | if canvas is None: 55 | return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3 56 | 57 | image = canvas_to_image_and_mask(canvas)[0] 58 | image_h0, image_w0, _ = image.shape 59 | image_w1, image_h1 = get_W_H(image_length, image_w0 / image_h0) 60 | image = cv2.resize(image, 
(image_w1, image_h1)) 61 | 62 | # helpful when the uploaded Gradio image is wider than it is tall (width > height) 63 | canvas_w1, canvas_h1 = (canvas_length, canvas_length) if image_h0 > image_w0 else get_W_H(canvas_length, image_w0 / image_h0) 64 | return (gr.Image(value=image, width=canvas_w1, height=canvas_h1),) * 3 65 | 66 | def wrong_upload(): 67 | """Prevent the user from uploading an image in the second column.""" 68 | gr.Warning('You should upload an image on the left.') 69 | return None 70 | 71 | def save_data(original_image, input_image, preview_image, mask, src_points, trg_points, prompt, data_path): 72 | """ 73 | Save the following under the `data_path` directory: 74 | (1) original_image.png : original image 75 | (2) user_drag_region.png : input image 76 | (3) preview.png : preview image 77 | (4) meta_data_mask.pkl : {'prompt' : text prompt, 78 | 'points' : [(x1,y1), (x2,y2), ..., (xn, yn)], 79 | 'mask' : binary mask np.uint8 (H, W)} 80 | (x1, y1), (x3, y3), ... are from source (handle) points 81 | (x2, y2), (x4, y4), ... 
are from target points 82 | """ 83 | os.makedirs(data_path, exist_ok=True) 84 | img_path = os.path.join(data_path, 'original_image.png') 85 | input_img_path = os.path.join(data_path, 'user_drag_region.png') 86 | preview_img_path = os.path.join(data_path, 'preview.png') 87 | meta_data_path = os.path.join(data_path, 'meta_data_mask.pkl') 88 | 89 | Image.fromarray(original_image).save(img_path) 90 | Image.fromarray(input_image).save(input_img_path) 91 | Image.fromarray(preview_image).save(preview_img_path) 92 | 93 | points = [point for pair in zip(src_points, trg_points) for point in pair] 94 | 95 | meta_data = { 96 | 'prompt': prompt, 97 | 'points': points, 98 | 'mask': mask 99 | } 100 | with open(meta_data_path, 'wb') as file: 101 | pickle.dump(meta_data, file, protocol=pickle.HIGHEST_PROTOCOL) 102 | 103 | @torch.no_grad() 104 | def run_process(canvas, input_image, preview_image, src_points, trg_points, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed, progress=gr.Progress()): 105 | if canvas is None: 106 | return None 107 | 108 | original_image, mask = canvas_to_image_and_mask(canvas) 109 | mask = np.ones_like(mask) 110 | 111 | if src_points is None or len(src_points) == 0: 112 | return original_image 113 | 114 | drag_data = { 115 | 'ori_image' : original_image, 116 | 'preview' : preview_image, 117 | 'prompt' : prompt, 118 | 'mask' : mask, 119 | 'source' : src_points, 120 | 'target' : trg_points 121 | } 122 | 123 | return drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress.tqdm, method, data_path) 124 | 125 | # 2. 
mask utils (region pairs) 126 | def clear_all_m(length): 127 | """Used to initialize all inputs in the UI's region pair tab.""" 128 | return (gr.Image(value=None, height=length, width=length),) * 3 + \ 129 | ([], "A photo of an object.", "output/default", 20, 0.6, None, None) 130 | 131 | def draw_input_image_m(canvas, selected_masks): 132 | """Draw an image reflecting the user's intentions.""" 133 | image = canvas_to_image_and_mask(canvas)[0] 134 | for i, mask in enumerate(selected_masks): 135 | color = [255, 0, 0] if i % 2 == 0 else [0, 0, 255] 136 | image = mask_image(image, mask, color=color, alpha=0.3) 137 | draw_mask_border(image, mask) 138 | return image 139 | 140 | @torch.no_grad() 141 | def region_pair_to_pts(src_region, trg_region, scale=1): 142 | """ 143 | Perform dense mapping between one source (handle) region and one target region. 144 | `scale` is set to 1/8 for mapping in SD latent space. 145 | """ 146 | 147 | def mask_min_max(tensor, mask, dim=None): 148 | """ 149 | Compute the masked min and max of a tensor along a given dimension. 
150 | """ 151 | # Fill masked-out entries with +inf before taking the min and with -inf before taking the max 152 | masked_tensor = torch.where(mask, tensor, torch.inf) 153 | masked_min = torch.min(masked_tensor, dim=dim)[0] if dim is not None else torch.min(masked_tensor) 154 | 155 | masked_tensor = torch.where(mask, tensor, -torch.inf) 156 | masked_max = torch.max(masked_tensor, dim=dim)[0] if dim is not None else torch.max(masked_tensor) 157 | return masked_min, masked_max 158 | 159 | src_region = cv2.resize(src_region, (int(src_region.shape[1]*scale), int(src_region.shape[0]*scale))) 160 | trg_region = cv2.resize(trg_region, (int(trg_region.shape[1]*scale), int(trg_region.shape[0]*scale))) 161 | 162 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 163 | src_region = torch.from_numpy(src_region).to(device).bool() 164 | trg_region = torch.from_numpy(trg_region).to(device).bool() 165 | 166 | h, w = src_region.shape 167 | src_grid = trg_grid = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij'), dim=-1).to(device).float() 168 | trg_pts = trg_grid[torch.where(trg_region)] 169 | 170 | src_x_min, src_x_max = mask_min_max(src_grid[:, :, 1], mask=src_region) 171 | trg_x_min, trg_x_max = mask_min_max(trg_grid[:, :, 1], mask=trg_region) 172 | 173 | scale_x = (src_x_max - src_x_min) / (trg_x_max - trg_x_min).clamp(min=1e-4) 174 | 175 | trg_grid[:, :, 1] = ((trg_grid[:, :, 1] - trg_x_min) * scale_x + src_x_min) 176 | trg_grid[:, :, 1] = torch.where(trg_region, trg_grid[:, :, 1], 0) 177 | 178 | src_y_min, src_y_max = mask_min_max(src_grid[:, :, 0], mask=src_region, dim=0) 179 | trg_y_min, trg_y_max = mask_min_max(trg_grid[:, :, 0], mask=trg_region, dim=0) 180 | src_y_min, src_y_max = src_y_min[trg_grid[:, :, 1].int()], src_y_max[trg_grid[:, :, 1].int()] 181 | 182 | scale_y = (src_y_max - src_y_min) / (trg_y_max - trg_y_min).clamp(min=1e-4) 183 | trg_grid[:, :, 0] = ((trg_grid[:, :, 0] - trg_y_min) * scale_y + src_y_min) 184 | warp_trg_pts = 
trg_grid[torch.where(trg_region)] 185 | 186 | return warp_trg_pts[:, [1, 0]].int(), trg_pts[:, [1, 0]].int() 187 | 188 | def preview_out_image_m(canvas, selected_masks): 189 | """Preview the output image by directly copy-pasting pixel values.""" 190 | if canvas is None: 191 | return None, None, None 192 | image = canvas_to_image_and_mask(canvas)[0] 193 | 194 | if len(selected_masks) < 2: 195 | return image, None, None 196 | 197 | src_regions = selected_masks[0:-1:2] 198 | trg_regions = selected_masks[1::2] 199 | 200 | src_points, trg_points = map(torch.cat, zip(*[region_pair_to_pts(src_region, trg_region) for src_region, trg_region in zip(src_regions, trg_regions)])) 201 | src_idx, trg_idx = src_points[:, [1, 0]].cpu().numpy(), trg_points[:, [1, 0]].cpu().numpy() 202 | image[tuple(trg_idx.T)] = image[tuple(src_idx.T)] 203 | 204 | src_points, trg_points = map(torch.cat, zip(*[region_pair_to_pts(src_region, trg_region, scale=1/8) for src_region, trg_region in zip(src_regions, trg_regions)])) 205 | return image, src_points*8, trg_points*8 206 | 207 | def add_mask(canvas, selected_masks): 208 | """Add a drawn mask, and draw input image""" 209 | if canvas is None: 210 | return None 211 | image, mask = canvas_to_image_and_mask(canvas) 212 | if len(selected_masks) >= 1 and (mask == 0).all(): 213 | gr.Warning('Do not input empty region.') 214 | else: 215 | selected_masks.append(mask) 216 | return draw_input_image_m(canvas, selected_masks) 217 | 218 | def undo_mask(canvas, selected_masks): 219 | """Undo a drawn mask, and draw input image""" 220 | if len(selected_masks) > 0: 221 | selected_masks.pop() 222 | if canvas is None: 223 | return None 224 | return draw_input_image_m(canvas, selected_masks) 225 | 226 | def clear_masks(canvas, selected_masks): 227 | """Clear all drawn masks, and draw input image""" 228 | selected_masks.clear() 229 | if canvas is None: 230 | return None 231 | return draw_input_image_m(canvas, selected_masks) 232 | 233 | # 3. 
vertex utils (polygon pairs) 234 | def clear_all(length): 235 | """Used to initialize all inputs in the UI's vertex pair tab.""" 236 | return (gr.Image(value=None, height=length, width=length),) * 3 + \ 237 | ([], [], "A photo of an object.", "output/default", 20, 0.6, None, None) 238 | 239 | def draw_input_image(canvas, selected_points, selected_shapes): 240 | """Draw the input image with vertices.""" 241 | # Extract the image and mask from the canvas 242 | image, mask = canvas_to_image_and_mask(canvas) 243 | if mask.sum() > 0: 244 | gr.Info('The drawn mask is not used.') 245 | mask = np.ones_like(mask) 246 | # If a mask is present (i.e., sum of mask values > 0), non-masked parts will be darkened 247 | masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) if mask.sum() > 0 else image 248 | 249 | def draw_circle(image, point, text, color): 250 | font = cv2.FONT_HERSHEY_SIMPLEX 251 | font_scale = 0.3 252 | font_color = (255, 255, 255) 253 | font_thickness = 1 254 | text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0] 255 | bottom_left_corner = (point[0] - text_size[0] // 2, point[1] + text_size[1] // 2) 256 | 257 | cv2.circle(image, tuple(point), 8, color, -1) 258 | cv2.circle(image, tuple(point), 8, [255, 255, 255]) 259 | cv2.putText(image, text, bottom_left_corner, font, font_scale, font_color, font_thickness) 260 | 261 | def draw_polygon(image, points, shape, color): 262 | if len(points) != shape: 263 | return image 264 | mask = np.zeros(image.shape[:2], dtype=np.uint8) 265 | points = np.array(points).reshape((-1, 1, 2)) 266 | cv2.fillPoly(mask, [points], 1) 267 | return mask_image(image, mask, color, alpha=0.3) 268 | 269 | if len(selected_points) == 0: 270 | return masked_img 271 | 272 | start_idx = 1 273 | for points, shape in zip(selected_points, selected_shapes): 274 | src_pts, trg_pts = points[:shape], points[shape:] 275 | masked_img = draw_polygon(masked_img, src_pts, shape, color=[255, 0, 0]) 276 | masked_img = 
draw_polygon(masked_img, trg_pts, shape, color=[0, 0, 255]) 277 | 278 | for i, src_pt in enumerate(src_pts, start=start_idx): 279 | draw_circle(masked_img, src_pt, str(i), [255, 102, 102]) 280 | for i, trg_pt in enumerate(trg_pts, start=start_idx): 281 | draw_circle(masked_img, trg_pt, str(i), [102, 102, 255]) 282 | start_idx = i + 1 283 | 284 | return masked_img 285 | 286 | def transform_polygon(src_points, trg_points, scale=1): 287 | """ 288 | Perform dense mapping with source (handle) and target triangle or quadrilateral. 289 | """ 290 | def get_points_inside_polygon(points): 291 | points = np.array(points, dtype=np.int32) 292 | x_max, y_max = np.amax(points, axis=0) + 1 293 | mask = np.zeros((y_max, x_max), dtype=np.uint8) 294 | cv2.fillPoly(mask, [points], 1) 295 | return np.column_stack(np.where(mask == 1))[:, [1, 0]] 296 | 297 | if len(trg_points) not in [3, 4]: 298 | raise NotImplementedError('Only triangles and quadrilaterals are implemented') 299 | 300 | src_points, trg_points = np.float32(src_points)*scale, np.float32(trg_points)*scale 301 | M = (cv2.getAffineTransform if len(trg_points) == 3 else cv2.getPerspectiveTransform)(trg_points, src_points) 302 | points_inside = get_points_inside_polygon(trg_points) 303 | warped_points = (cv2.transform if len(trg_points) == 3 else cv2.perspectiveTransform)(np.array([points_inside], dtype=np.float32), M) 304 | 305 | return warped_points[0].astype(np.int32), points_inside 306 | 307 | def preview_out_image(canvas, selected_points, selected_shapes): 308 | if canvas is None: 309 | return None, None, None 310 | image = canvas_to_image_and_mask(canvas)[0] 311 | 312 | selected_points = selected_points.copy() 313 | selected_shapes = selected_shapes.copy() 314 | 315 | if len(selected_points) == 0: 316 | return image, None, None 317 | 318 | if len(selected_points[-1]) != selected_shapes[-1] * 2: 319 | selected_points.pop() 320 | selected_shapes.pop() 321 | if len(selected_points) == 0: 322 | return image, None, None 323 | 
324 | src_points, trg_points = map(np.concatenate, zip(*[transform_polygon(sp[:ss], sp[ss:]) for sp, ss in zip(selected_points, selected_shapes)])) 325 | src_idx, trg_idx = src_points[:, [1, 0]], trg_points[:, [1, 0]] 326 | image[tuple(trg_idx.T)] = image[tuple(src_idx.T)] 327 | 328 | src_points, trg_points = map(np.concatenate, zip(*[transform_polygon(sp[:ss], sp[ss:], scale=1/8) for sp, ss in zip(selected_points, selected_shapes)])) 329 | return image, src_points*8, trg_points*8 330 | 331 | def add_point(canvas, shape, selected_points, selected_shapes, evt: gr.SelectData): 332 | """Collect the selected point, and draw the input image.""" 333 | if canvas is None: 334 | return None 335 | 336 | def is_valid_quadrilateral(p1, p2, p3, p4): 337 | def orientation(p, q, r): 338 | val = (q[1] - p[1]) * (r[0] - q[0]) - (q[0] - p[0]) * (r[1] - q[1]) 339 | if val == 0: 340 | return 0 # collinear 341 | return 1 if val > 0 else 2 # clock or counterclock wise 342 | 343 | def do_intersect(a1, a2, b1, b2): 344 | o1 = orientation(a1, a2, b1) 345 | o2 = orientation(a1, a2, b2) 346 | o3 = orientation(b1, b2, a1) 347 | o4 = orientation(b1, b2, a2) 348 | return (o1 != o2 and o3 != o4) and not (o1 == 0 or o2 == 0 or o3 == 0 or o4 == 0) 349 | 350 | # Check for collinearity and intersection between non-adjacent edges 351 | return not (orientation(p1, p2, p3) == 0 or 352 | orientation(p1, p2, p4) == 0 or 353 | orientation(p2, p3, p4) == 0 or 354 | orientation(p1, p3, p4) == 0 or 355 | do_intersect(p1, p2, p3, p4) or 356 | do_intersect(p2, p3, p1, p4)) 357 | 358 | if len(selected_points) == 0 or len(selected_points[-1]) == 2 * selected_shapes[-1]: 359 | selected_points.append([evt.index]) 360 | selected_shapes.append(shape+3) 361 | else: 362 | selected_points[-1].append(evt.index) 363 | if selected_shapes[-1] == 4 and len(selected_points[-1]) % selected_shapes[-1] == 0: 364 | if not is_valid_quadrilateral(*selected_points[-1][-4:]): 365 | gr.Warning('The drawn quadrilateral is not valid.') 
366 | selected_points[-1].pop() 367 | 368 | return draw_input_image(canvas, selected_points, selected_shapes) 369 | 370 | def update_shape(canvas, shape, selected_points, selected_shapes): 371 | """Allow users to switch between different shape options""" 372 | if canvas is None: 373 | return None 374 | if len(selected_points) == 0: 375 | return draw_input_image(canvas, selected_points, selected_shapes) 376 | if len(selected_points[-1]) == selected_shapes[-1] * 2: 377 | return draw_input_image(canvas, selected_points, selected_shapes) 378 | 379 | selected_shapes[-1] = shape + 3 380 | selected_points[-1] = selected_points[-1][ : (shape+3)*2] 381 | return draw_input_image(canvas, selected_points, selected_shapes) 382 | 383 | def undo_point(canvas, shape, selected_points, selected_shapes): 384 | """Remove the last added point, and draw the input image.""" 385 | if canvas is None: 386 | return None 387 | if len(selected_points) == 0: 388 | return draw_input_image(canvas, selected_points, selected_shapes) 389 | 390 | selected_points[-1].pop() 391 | if len(selected_points[-1]) == 0: 392 | selected_points.pop() 393 | selected_shapes.pop() 394 | update_shape(canvas, shape, selected_points, selected_shapes) 395 | return draw_input_image(canvas, selected_points, selected_shapes) 396 | 397 | def clear_points(canvas, selected_points, selected_shapes): 398 | """Clear all the points""" 399 | selected_points.clear() 400 | selected_shapes.clear() 401 | if canvas is None: 402 | return None 403 | return draw_input_image(canvas, selected_points, selected_shapes) 404 | 405 | # 4. 
region utils (region + point pair) 406 | def draw_input_image_r(canvas, selected_points): 407 | image, mask = canvas_to_image_and_mask(canvas) 408 | image = mask_image(image, mask, color=[255, 0, 0], alpha=0.3) 409 | draw_mask_border(image, mask) 410 | for idx, point in enumerate(selected_points, start=1): 411 | if idx % 2 == 0: 412 | cv2.circle(image, tuple(point), 10, (0, 0, 255), -1) 413 | cv2.arrowedLine(image, last_point, point, (255, 255, 255), 4, tipLength=0.5) 414 | else: 415 | cv2.circle(image, tuple(point), 10, (255, 0, 0), -1) 416 | last_point = point 417 | return image 418 | 419 | 420 | def region_to_points(region_mask, selected_points, scale=1): 421 | """ 422 | Process a region mask and a list of selected points to find corresponding 423 | source and target points within the region scaled by a factor. 424 | """ 425 | def resize_region_mask(mask, scale_factor): 426 | """Resize the region mask by a scale factor.""" 427 | H, W = mask.shape 428 | new_H, new_W = (int(H * scale_factor), int(W * scale_factor)) 429 | return cv2.resize(mask, (new_W, new_H)), (new_H, new_W) 430 | 431 | def find_contours(mask): 432 | """Find contours in the mask.""" 433 | return cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] 434 | 435 | def scale_points(points, scale_factor): 436 | """Scale points by a scale factor.""" 437 | return (np.array(points) * scale_factor).astype(np.int32) 438 | 439 | def draw_mask_from_contours(contour, shape): 440 | """Draws a mask from given contours.""" 441 | mask = np.zeros(shape, dtype=np.uint8) 442 | contour = contour[:, np.newaxis, :] if contour.ndim == 2 else contour 443 | cv2.drawContours(mask, [contour], -1, color=255, thickness=cv2.FILLED) 444 | return mask 445 | 446 | def find_points_in_mask(mask): 447 | """Find the coordinates of non-zero points in the mask.""" 448 | return np.column_stack(np.where(mask)).astype(np.int32)[:, [1, 0]] 449 | 450 | def filter_points_by_bounds(source_points, target_points, bounds): 451 | 
"""Filter points by checking if they fall within the given bounds.""" 452 | height, width = bounds 453 | src_condition = ( 454 | (source_points[:, 0] >= 0) & (source_points[:, 0] < width) & 455 | (source_points[:, 1] >= 0) & (source_points[:, 1] < height) 456 | ) 457 | trg_condition = ( 458 | (target_points[:, 0] >= 0) & (target_points[:, 0] < width) & 459 | (target_points[:, 1] >= 0) & (target_points[:, 1] < height) 460 | ) 461 | condition = src_condition & trg_condition 462 | return source_points[condition], target_points[condition] 463 | 464 | def find_matching_points(source_region, source_points, target_points): 465 | """Find source points in source_region and their matching target points.""" 466 | match_indices = np.all(source_points[:, None] == source_region, axis=2).any(axis=1) 467 | return source_points[match_indices], target_points[match_indices] 468 | 469 | def interpolate_points(region_points, reference_points, directions, max_num=100): 470 | """Interpolate points within a region based on reference points and their directions.""" 471 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 472 | 473 | # Convert input numpy arrays to PyTorch tensors and send them to the appropriate device 474 | region_points = torch.from_numpy(region_points).half().to(device) 475 | reference_points = torch.from_numpy(reference_points).half().to(device) 476 | directions = torch.from_numpy(directions).half().to(device) 477 | 478 | if len(reference_points) < 2: 479 | return (region_points + directions).int().to('cpu').numpy() 480 | 481 | if len(reference_points) > max_num: 482 | indices = torch.linspace(0, len(reference_points) - 1, steps=max_num).long().to(device) 483 | reference_points = reference_points[indices] 484 | directions = directions[indices] 485 | 486 | distance = torch.norm(region_points.unsqueeze(1) - reference_points.unsqueeze(0), dim=-1) 487 | _, indices = torch.sort(distance, dim=1) 488 | indices = indices[:, :min(4, 
reference_points.shape[0])] 489 | 490 | directions = torch.gather(directions.unsqueeze(0).expand(region_points.size(0), -1, -1), 1, indices.unsqueeze(-1).expand(-1, -1, directions.size(-1))) 491 | 492 | inv_distance = 1 / (torch.gather(distance, 1, indices) + 1e-4) 493 | weight = inv_distance / inv_distance.sum(dim=1, keepdim=True) 494 | 495 | estimated_direction = (weight.unsqueeze(-1) * directions).sum(dim=1) 496 | return (region_points + estimated_direction).round().int().to('cpu').numpy() 497 | 498 | resized_mask, new_size = resize_region_mask(region_mask, scale) 499 | 500 | contours = find_contours(resized_mask) 501 | source_points = scale_points(selected_points[0:-1:2], scale) 502 | target_points = scale_points(selected_points[1::2], scale) 503 | 504 | source_regions = [np.zeros((0, 2)),]; target_regions = [np.zeros((0, 2)),] 505 | for contour in contours: 506 | # find point pairs used to manipulate the region inside this contour 507 | source_contour = contour[:, 0, :] 508 | source_region_points = find_points_in_mask(draw_mask_from_contours(source_contour, new_size)) 509 | source, target = find_matching_points(source_region_points, source_points, target_points) 510 | 511 | # interpolate to find the motion of the contour points and of the points inside the contour 512 | if len(source) == 0: 513 | continue 514 | target_contour = interpolate_points(source_contour, source, target - source) 515 | interpolated_target_points = interpolate_points(source_region_points, source, target - source) 516 | 517 | # similar to above, this step ensures that every point inside the target region has a reference point 518 | target_region_points = find_points_in_mask(draw_mask_from_contours(target_contour, new_size)) 519 | interpolated_source_points = interpolate_points(target_region_points, interpolated_target_points, source_region_points - interpolated_target_points) 520 | 521 | filtered_source, filtered_target = filter_points_by_bounds(interpolated_source_points, target_region_points, new_size) 
522 | source_regions.append(filtered_source) 523 | target_regions.append(filtered_target) 524 | 525 | return np.concatenate(source_regions).astype(np.int32), np.concatenate(target_regions).astype(np.int32) 526 | 527 | def preview_out_image_r(canvas, selected_points): 528 | if canvas is None: 529 | return None 530 | image, region_mask = canvas_to_image_and_mask(canvas) 531 | 532 | if len(selected_points) < 2: 533 | return image, None, None 534 | 535 | src_points, trg_points = region_to_points(region_mask, selected_points) 536 | image[trg_points[:, 1], trg_points[:, 0]] = image[src_points[:, 1], src_points[:, 0]] 537 | src_points, trg_points = region_to_points(region_mask, selected_points, scale=1/8) 538 | return image, src_points*8, trg_points*8 539 | 540 | def add_point_r(canvas, selected_points, evt: gr.SelectData): 541 | if canvas is None: 542 | return None 543 | selected_points.append(evt.index) 544 | return draw_input_image_r(canvas, selected_points) 545 | 546 | def undo_point_r(canvas, selected_points): 547 | if canvas is None: 548 | return None 549 | if len(selected_points) > 0: 550 | selected_points.pop() 551 | return draw_input_image_r(canvas, selected_points) 552 | 553 | def clear_points_r(canvas, selected_points): 554 | if canvas is None: 555 | return None 556 | selected_points.clear() 557 | return draw_input_image_r(canvas, selected_points) 558 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.0 2 | diffusers==0.30.1 3 | transformers==4.40.2 4 | safetensors==0.4.3 5 | einops==0.7.0 6 | lpips==0.1.4 7 | cmake==3.25.0 8 | gradio==3.47.1 9 | matplotlib==3.7.1 10 | numpy==1.24.1 11 | opencv-python==4.8.0.76 12 | pandas==2.0.2 13 | Pillow==9.3.0 14 | torch==2.0.1 15 | torchaudio==2.0.2 16 | torchvision==0.15.2 17 | tqdm==4.65.0 18 | xformers==0.0.20 19 | yacs==0.1.8 
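Note on `interpolate_points` in `region_utils/ui_utils.py`: each region point is moved by an inverse-distance-weighted average of the displacements of its nearest reference points (at most four). The sketch below restates that weighting in plain NumPy so it can be tried without a GPU. The function name `idw_displace` and its `k` and `eps` parameters are hypothetical and not part of the repository. Unlike the original, it uses float64 instead of half precision, skips the `max_num` subsampling, and does not special-case a single reference point.

```python
import numpy as np

def idw_displace(region_points, reference_points, directions, k=4, eps=1e-4):
    """Move each region point by the inverse-distance-weighted mean
    displacement of its k nearest reference points (illustrative only)."""
    region_points = np.asarray(region_points, dtype=np.float64)
    reference_points = np.asarray(reference_points, dtype=np.float64)
    directions = np.asarray(directions, dtype=np.float64)

    # Pairwise distances, shape (num_region_points, num_reference_points).
    diff = region_points[:, None, :] - reference_points[None, :, :]
    distance = np.linalg.norm(diff, axis=-1)

    # Indices of the k nearest reference points for every region point.
    k = min(k, len(reference_points))
    nearest = np.argsort(distance, axis=1)[:, :k]

    # Inverse-distance weights, normalised to sum to 1 per region point.
    inv = 1.0 / (np.take_along_axis(distance, nearest, axis=1) + eps)
    weight = inv / inv.sum(axis=1, keepdims=True)

    # Weighted average of the neighbours' displacement vectors.
    estimated = (weight[..., None] * directions[nearest]).sum(axis=1)
    return np.rint(region_points + estimated).astype(np.int32)


# Two reference points that both move 2 px to the right:
# every region point should follow the same shift.
moved = idw_displace([[5, 5], [0, 0]], [[0, 0], [10, 0]], [[2, 0], [2, 0]])
print(moved.tolist())  # [[7, 5], [2, 0]]
```

The small `eps` term plays the same role as the `1e-4` in the original: it avoids a division by zero when a region point coincides with a reference point, in which case that reference point dominates the average.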
-------------------------------------------------------------------------------- /run_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from tqdm import tqdm 5 | import gradio as gr 6 | 7 | from region_utils.drag import drag, get_drag_data, get_meta_data 8 | from region_utils.evaluator import DragEvaluator 9 | 10 | # Setting up the argument parser 11 | parser = argparse.ArgumentParser(description='Run the drag operation.') 12 | parser.add_argument('--data_dir', type=str, default='drag_data/dragbench-dr/') # OR 'drag_data/dragbench-sr/' 13 | args = parser.parse_args() 14 | 15 | evaluator = DragEvaluator() 16 | all_distances = []; all_lpips = [] 17 | 18 | data_dir = args.data_dir 19 | data_dirs = [dirpath for dirpath, dirnames, _ in os.walk(data_dir) if not dirnames] 20 | 21 | start_t = 0.5 22 | end_t = 0.2 23 | steps = 20 24 | noise_scale = 1. 25 | seed = 42 26 | 27 | for data_path in tqdm(data_dirs): 28 | # Region-based Inputs for Editing 29 | drag_data = get_drag_data(data_path) 30 | ori_image = drag_data['ori_image'] 31 | out_image = drag(drag_data, steps, start_t, end_t, noise_scale, seed, progress=gr.Progress()) 32 | 33 | # Point-based Inputs for Evaluation 34 | meta_data_path = os.path.join(data_path, 'meta_data.pkl') 35 | prompt, _, source, target = get_meta_data(meta_data_path) 36 | 37 | all_distances.append(evaluator.compute_distance(ori_image, out_image, source, target, method='sd', prompt=prompt)) 38 | all_lpips.append(evaluator.compute_lpips(ori_image, out_image)) 39 | 40 | if all_distances: 41 | mean_dist = torch.tensor(all_distances).mean().item() 42 | mean_lpips = torch.tensor(all_lpips).mean().item() 43 | print(f'MD: {mean_dist:.4f}\nLPIPS: {mean_lpips:.4f}\n') -------------------------------------------------------------------------------- /ui.py: -------------------------------------------------------------------------------- 1 | import torch 2 | if not 
torch.cuda.is_available(): 3 | raise RuntimeError('CUDA is not available, but required.') 4 | 5 | import gradio as gr 6 | from region_utils.ui_utils import * 7 | from region_utils.drag import sd_version 8 | 9 | LENGTH = 400 # length of image in Gradio App, you can adjust it according to your screen size 10 | GEN_SIZE = {'v1-5': 512, 'v2-1': 768, 'xl': 1024}[sd_version] # Default generated image size 11 | 12 | def main(): 13 | with gr.Blocks() as demo: 14 | gr_length = gr.Number(value=LENGTH, visible=False, precision=0) 15 | gr_gen_size = gr.Number(value=GEN_SIZE, visible=False, precision=0) 16 | 17 | selected_masks = gr.State(value=[]) 18 | src_points_m = gr.State(value=None); trg_points_m = gr.State(value=None) 19 | 20 | selected_points = gr.State(value=[]) 21 | selected_shapes = gr.State(value=[]) 22 | src_points = gr.State(value=None); trg_points = gr.State(value=None) 23 | 24 | selected_points_r = gr.State(value=[]) 25 | src_points_r = gr.State(value=None); trg_points_r = gr.State(value=None) 26 | 27 | seed = gr.Number(value=42, label="Generation Seed", precision=0, visible=False) 28 | start_t = gr.Number(value=0.5, visible=False) 29 | end_t = gr.Number(value=0.2, visible=False) 30 | 31 | # layout definition 32 | with gr.Row(): 33 | gr.Markdown(""" 34 | # Official Implementation of [RegionDrag](https://arxiv.org/abs/2407.18247) 35 | #### Explore our detailed [User Guide](https://github.com/LuJingyi-John/RegionDrag/blob/main/README.md) for interface instructions. 36 | """) 37 | 38 | with gr.Row(): 39 | with gr.Tab(label='Region pairs'): 40 | with gr.Row(): 41 | with gr.Column(): 42 | gr.Markdown("""

1. Upload image and add regions

""") 43 | canvas_m = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH) 44 | with gr.Row(): 45 | resize_button_m = gr.Button("Fit Canvas") 46 | add_mask_button = gr.Button("Add Region") 47 | with gr.Column(): 48 | gr.Markdown("""

2. View Input

""") 49 | input_image_m = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False) 50 | with gr.Row(): 51 | undo_mask_button = gr.Button("Undo Region") 52 | clear_mask_button = gr.Button("Clear Region") 53 | with gr.Column(): 54 | gr.Markdown("""

Results

""") 55 | output_image_m = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False) 56 | with gr.Row(): 57 | run_button_m = gr.Button("Run Drag") 58 | clear_all_button_m = gr.Button("Clear All") 59 | 60 | with gr.Tab(label='Polygon pairs'): 61 | with gr.Row(): 62 | with gr.Column(): 63 | gr.Markdown("""

1. Upload image

""") 64 | canvas = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH, interactive=True) 65 | with gr.Row(): 66 | resize_button = gr.Button("Fit Canvas") 67 | with gr.Column(): 68 | gr.Markdown("""

2. Click vertices

""") 69 | input_image = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=True) 70 | with gr.Row(): 71 | undo_point_button = gr.Button("Undo Point") 72 | clear_point_button = gr.Button("Clear Point") 73 | with gr.Column(): 74 | gr.Markdown("""

Results

""") 75 | output_image = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False) 76 | with gr.Row(): 77 | run_button = gr.Button("Run Drag") 78 | clear_all_button = gr.Button("Clear All") 79 | shape = gr.Radio(choices=['▲ Tri', '■ Quad'], value='■ Quad', type='index', label='Mask Shape', interactive=True) 80 | 81 | with gr.Tab(label='Region + points'): 82 | with gr.Row(): 83 | with gr.Column(): 84 | gr.Markdown("""

1. Upload image and draw regions

""") 85 | canvas_r = gr.Image(type="numpy", tool="sketch", label=" ", height=LENGTH, width=LENGTH) 86 | with gr.Row(): 87 | resize_button_r = gr.Button("Fit Canvas") 88 | with gr.Column(): 89 | gr.Markdown("""

2. Click points to control regions

""") 90 | input_image_r = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=True) 91 | with gr.Row(): 92 | undo_point_button_r = gr.Button("Undo Point") 93 | clear_point_button_r = gr.Button("Clear Point") 94 | with gr.Column(): 95 | gr.Markdown("""

Results

""") 96 | output_image_r = gr.Image(type="numpy", label=" ", height=LENGTH, width=LENGTH, interactive=False) 97 | with gr.Row(): 98 | run_button_r = gr.Button("Run Drag") 99 | clear_all_button_r = gr.Button("Clear All") 100 | 101 | with gr.Tab("Generation Parameters"): 102 | with gr.Row(): 103 | prompt = gr.Textbox(label="Prompt describing output image (Optional)", value='A photo of an object.') 104 | data_path = gr.Textbox(value='output/default', label="Output path") 105 | with gr.Row(): 106 | steps = gr.Slider(minimum=20, maximum=100, value=20, step=20, label='Sampling steps', interactive=True) 107 | noise_scale = gr.Slider(minimum=0, maximum=1.6, value=0.6, step=0.2, label='Handle Noise Scale', interactive=True) # alpha 108 | method = gr.Dropdown(choices=['InstantDrag', 'Encode then CP', 'CP then Encode'], value='InstantDrag', label='Method', interactive=True) 109 | 110 | clear_all_button_m.click( 111 | clear_all_m, 112 | [gr_length], 113 | [canvas_m, input_image_m, output_image_m, selected_masks, prompt, data_path, steps, noise_scale, src_points_m, trg_points_m] 114 | ) 115 | canvas_m.clear( 116 | clear_all_m, 117 | [gr_length], 118 | [canvas_m, input_image_m, output_image_m, selected_masks, prompt, data_path, steps, noise_scale, src_points_m, trg_points_m] 119 | ) 120 | resize_button_m.click( 121 | clear_masks, 122 | [canvas_m, selected_masks], 123 | [input_image_m] 124 | ).then( 125 | resize_image, 126 | [canvas_m, gr_length, gr_gen_size], 127 | [canvas_m, input_image_m, output_image_m] 128 | ) 129 | add_mask_button.click( 130 | add_mask, 131 | [canvas_m, selected_masks], 132 | [input_image_m] 133 | ).then( 134 | preview_out_image_m, 135 | [canvas_m, selected_masks], 136 | [output_image_m, src_points_m, trg_points_m] 137 | ) 138 | undo_mask_button.click( 139 | undo_mask, 140 | [canvas_m, selected_masks], 141 | [input_image_m] 142 | ).then( 143 | preview_out_image_m, 144 | [canvas_m, selected_masks], 145 | [output_image_m, src_points_m, trg_points_m] 146 | ) 
147 | clear_mask_button.click( 148 | clear_masks, 149 | [canvas_m, selected_masks], 150 | [input_image_m] 151 | ).then( 152 | preview_out_image_m, 153 | [canvas_m, selected_masks], 154 | [output_image_m, src_points_m, trg_points_m] 155 | ) 156 | run_button_m.click( 157 | preview_out_image_m, 158 | [canvas_m, selected_masks], 159 | [output_image_m, src_points_m, trg_points_m] 160 | ).then( 161 | run_process, 162 | [canvas_m, input_image_m, output_image_m, src_points_m, trg_points_m, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed], 163 | [output_image_m] 164 | ) 165 | 166 | clear_all_button.click( 167 | clear_all, 168 | [gr_length], 169 | [canvas, input_image, output_image, selected_points, selected_shapes, prompt, data_path, steps, noise_scale, src_points, trg_points] 170 | ) 171 | canvas.clear( 172 | clear_all, 173 | [gr_length], 174 | [canvas, input_image, output_image, selected_points, selected_shapes, prompt, data_path, steps, noise_scale, src_points, trg_points] 175 | ) 176 | resize_button.click( 177 | clear_points, 178 | [canvas, selected_points, selected_shapes], 179 | [input_image] 180 | ).then( 181 | resize_image, 182 | [canvas, gr_length, gr_gen_size], 183 | [canvas, input_image, output_image] 184 | ) 185 | canvas.edit( 186 | draw_input_image, 187 | [canvas, selected_points, selected_shapes], 188 | input_image 189 | ).then( 190 | preview_out_image, 191 | [canvas, selected_points, selected_shapes], 192 | [output_image, src_points, trg_points] 193 | ) 194 | shape.change( 195 | update_shape, 196 | [canvas, shape, selected_points, selected_shapes], 197 | [input_image] 198 | ).then( 199 | preview_out_image, 200 | [canvas, selected_points, selected_shapes], 201 | [output_image, src_points, trg_points] 202 | ) 203 | input_image.upload( 204 | wrong_upload, 205 | outputs=[input_image] 206 | ) 207 | input_image.select( 208 | add_point, 209 | [canvas, shape, selected_points, selected_shapes], 210 | [input_image] 211 | ).then( 212 | 
preview_out_image, 213 | [canvas, selected_points, selected_shapes], 214 | [output_image, src_points, trg_points] 215 | ) 216 | undo_point_button.click( 217 | undo_point, 218 | [canvas, shape, selected_points, selected_shapes], 219 | [input_image] 220 | ).then( 221 | preview_out_image, 222 | [canvas, selected_points, selected_shapes], 223 | [output_image, src_points, trg_points] 224 | ) 225 | clear_point_button.click( 226 | clear_points, 227 | [canvas, selected_points, selected_shapes], 228 | [input_image] 229 | ).then( 230 | preview_out_image, 231 | [canvas, selected_points, selected_shapes], 232 | [output_image, src_points, trg_points] 233 | ) 234 | run_button.click( 235 | preview_out_image, 236 | [canvas, selected_points, selected_shapes], 237 | [output_image, src_points, trg_points] 238 | ).then( 239 | run_process, 240 | [canvas, input_image, output_image, src_points, trg_points, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed], 241 | [output_image] 242 | ) 243 | 244 | clear_all_button_r.click( 245 | clear_all_m, 246 | [gr_length], 247 | [canvas_r, input_image_r, output_image_r, selected_points_r, prompt, data_path, steps, noise_scale, src_points_r, trg_points_r] 248 | ) 249 | canvas_r.clear( 250 | clear_all_m, 251 | [gr_length], 252 | [canvas_r, input_image_r, output_image_r, selected_points_r, prompt, data_path, steps, noise_scale, src_points_r, trg_points_r] 253 | ) 254 | resize_button_r.click( 255 | clear_points_r, 256 | [canvas_r, selected_points_r], 257 | [input_image_r] 258 | ).then( 259 | resize_image, 260 | [canvas_r, gr_length, gr_gen_size], 261 | [canvas_r, input_image_r, output_image_r] 262 | ) 263 | canvas_r.edit( 264 | draw_input_image_r, 265 | [canvas_r, selected_points_r], 266 | [input_image_r] 267 | ).then( 268 | preview_out_image_r, 269 | [canvas_r, selected_points_r], 270 | [output_image_r, src_points_r, trg_points_r] 271 | ) 272 | input_image_r.upload( 273 | wrong_upload, 274 | outputs=[input_image_r] 275 | ) 276 | 
input_image_r.select( 277 | add_point_r, 278 | [canvas_r, selected_points_r], 279 | [input_image_r] 280 | ).then( 281 | preview_out_image_r, 282 | [canvas_r, selected_points_r], 283 | [output_image_r, src_points_r, trg_points_r] 284 | ) 285 | undo_point_button_r.click( 286 | undo_point_r, 287 | [canvas_r, selected_points_r], 288 | [input_image_r] 289 | ).then( 290 | preview_out_image_r, 291 | [canvas_r, selected_points_r], 292 | [output_image_r, src_points_r, trg_points_r] 293 | ) 294 | clear_point_button_r.click( 295 | clear_points_r, 296 | [canvas_r, selected_points_r], 297 | [input_image_r] 298 | ).then( 299 | preview_out_image_r, 300 | [canvas_r, selected_points_r], 301 | [output_image_r, src_points_r, trg_points_r] 302 | ) 303 | run_button_r.click( 304 | preview_out_image_r, 305 | [canvas_r, selected_points_r], 306 | [output_image_r, src_points_r, trg_points_r] 307 | ).then( 308 | run_process, 309 | [canvas_r, input_image_r, output_image_r, src_points_r, trg_points_r, prompt, start_t, end_t, steps, noise_scale, data_path, method, seed], 310 | [output_image_r] 311 | ) 312 | 313 | demo.queue().launch(share=True, debug=True) 314 | 315 | if __name__ == '__main__': 316 | main() --------------------------------------------------------------------------------
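Note on the 'Region + points' tab: `add_point_r` appends every click to one flat list, and `region_to_points` then reads that list as alternating source and target points using the slices `selected_points[0:-1:2]` and `selected_points[1::2]`. The sketch below isolates that convention. The helper name `split_click_pairs` is hypothetical and does not exist in the repository.

```python
def split_click_pairs(selected_points):
    """Split an alternating click list into (sources, targets).

    Clicks at even positions are sources, clicks at odd positions are
    targets. A trailing source click without a target is ignored.
    """
    sources = selected_points[0:-1:2]  # stops before the last element
    targets = selected_points[1::2]
    return sources, targets


clicks = [[10, 10], [30, 10], [50, 50]]  # third click has no target yet
sources, targets = split_click_pairs(clicks)
print(sources)  # [[10, 10]]
print(targets)  # [[30, 10]]
```

Because the source slice stops before the last element, an odd number of clicks still yields equally long source and target lists, so the preview can be refreshed after every click.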