├── .gitattributes ├── README.md ├── adversarial_latent_optimization.py ├── checkpoints └── download.sh ├── generate_prompts.py ├── get_model.py ├── image_latent_mapping.py ├── ptp_utils.py ├── seq_aligner.py ├── temp └── 1000 │ └── temp.sh ├── test_image.py ├── test_quality.py ├── third_party └── download.sh └── utils_sgm.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [[NeurIPS 2023] Content-based Unrestricted Adversarial Attack](https://openreview.net/pdf?id=gO60SSGOMy) 2 | 3 | Zhaoyu Chen, Bo Li, Shuang Wu, Kaixun Jiang, Shouhong Ding, Wenqiang Zhang 4 | 5 | This repository offers PyTorch code to reproduce the results from the paper. Please consider citing our paper if you find it interesting or helpful to your research. 6 | 7 | ``` 8 | @inproceedings{ 9 | chen2023contentbased, 10 | title={Content-based Unrestricted Adversarial Attack}, 11 | author={Zhaoyu Chen and Bo Li and Shuang Wu and Kaixun Jiang and Shouhong Ding and Wenqiang Zhang}, 12 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 13 | year={2023}, 14 | url={https://openreview.net/forum?id=gO60SSGOMy} 15 | } 16 | ``` 17 | 18 | 19 | ## Requirements 20 | 21 | - Python == 3.8.0 22 | - PyTorch == 1.12.1 23 | - torchvision == 0.13.1 24 | - CUDA == 11.3 25 | - timm == 0.9.5 26 | - TensorFlow == 2.11.0 27 | - diffusers == 0.3.0 28 | - huggingface-hub == 0.11.1 29 | - pyiqa == 0.1.6.3 30 | 31 | 32 | ## Quick Start 33 | 34 | - **Prepare models** 35 | 36 | Download the checkpoints of the pretrained models. 37 | 38 | ```bash 39 | cd checkpoints 40 | ./download.sh 41 | ``` 42 | 43 | - **Prepare datasets** 44 | 45 | We obtain the dataset from Natural-Color-Fool. 46 | 47 | ```bash 48 | cd third_party 49 | ./download.sh 50 | ``` 51 | 52 | - **Generate the prompts** 53 | 54 | Here, we use BLIP-2 to automatically generate a prompt for each image. 55 | 56 | ```bash 57 | TRANSFORMERS_OFFLINE=1 python3 generate_prompts.py 58 | ``` 59 | 60 | - **Image latent mapping** 61 | 62 | We use null-text inversion to map images into the latent space. 63 | 64 | ```bash 65 | python3 image_latent_mapping.py 66 | ``` 67 | 68 | - **Adversarial latent optimization** 69 | 70 | After the latents have been computed offline, we perform adversarial latent optimization to obtain adversarial examples (a simplified sketch of the update rule is given at the end of this README). 71 | 72 | ```bash 73 | CUDA_VISIBLE_DEVICES=0 python3 adversarial_latent_optimization.py --model mnv2 --beta 0.1 --alpha 0.04 --steps 10 --norm 2 --start 0 --end 1000 --mu 1 --eps 0.1 74 | ``` 75 | 76 | - **Evaluate the accuracy** 77 | 78 | Run the target model on the generated adversarial images. 79 | 80 | ```bash 81 | python3 test_image.py --model MODEL_NAME --img_path IMAGE_SAVE_PATH 82 | ``` 83 | 84 | - **Evaluate the image quality** 85 | 86 | Evaluate the quality of the generated adversarial images. 87 | 88 | ```bash 89 | python3 test_quality.py --metric METRIC --img_path IMAGE_SAVE_PATH 90 | ``` 91 | 92 | ## License 93 | This project is free for academic research purposes only; commercial use is not authorized. Part of the code is modified from Prompt-to-Prompt.
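## Latent update sketch

The adversarial latent optimization step above follows the momentum sign-update implemented in `adversarial_latent_optimization.py`: the gradient of the attack objective with respect to the denoised latent is L1-normalized, accumulated into a momentum buffer, and the inverted latent is moved by `--alpha` in the sign direction while the total perturbation is clipped to an L-infinity ball of radius `--eps`. The snippet below is a minimal sketch of that update rule only (the `--lp linf` branch); the DDIM denoising, VAE decoding, and classification that produce the gradient are abstracted away, and the function name `latent_update` as well as the small epsilon guard in the normalization are ours, not part of the repository.

```python
import torch

def latent_update(adv_latents, ori_latents, grad, momentum,
                  alpha=0.04, eps=0.1, mu=1.0):
    """One l_inf momentum step on the inverted latent (simplified sketch)."""
    # L1-normalize the gradient before accumulating momentum, mirroring
    # `l1_grad = grad / torch.norm(grad, p=1)` in the script
    # (the 1e-12 guard against division by zero is an addition here).
    l1_grad = grad / grad.abs().sum().clamp(min=1e-12)
    momentum = mu * momentum + l1_grad
    # Take a sign step of size alpha, then clip the total perturbation
    # to the l_inf ball of radius eps around the original latent.
    adv_latents = adv_latents + alpha * momentum.sign()
    noise = (adv_latents - ori_latents).clamp(-eps, eps)
    return ori_latents + noise, momentum
```

In the actual script, `grad` is the gradient of the attack objective (`loss_ce - loss_mse`) computed at the final denoised latent, the update is applied to the inverted starting latent, and each attack step re-runs the DDIM denoising loop from the updated latent before recomputing the loss.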
-------------------------------------------------------------------------------- /adversarial_latent_optimization.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import argparse 4 | from get_model import get_model 5 | from typing import Optional, Union, Tuple, List, Callable, Dict 6 | from tqdm import tqdm, trange 7 | import torch 8 | from diffusers import StableDiffusionPipeline, DDIMScheduler 9 | import torch.nn.functional as nnf 10 | import numpy as np 11 | import abc 12 | import ptp_utils 13 | import seq_aligner 14 | import shutil 15 | from torch.optim.adam import Adam 16 | from PIL import Image 17 | import time 18 | import torchvision 19 | import torch.backends.cudnn as cudnn 20 | import torchvision.transforms as transforms 21 | import torch.nn.functional as F 22 | from utils_sgm import register_hook_for_resnet, register_hook_for_densenet 23 | 24 | 25 | ''' 26 | CUDA_VISIBLE_DEVICES=0 python3 adversarial_latent_optimization.py --model mnv2 --beta 0.1 --alpha 0.04 --steps 10 --norm 2 --start 0 --end 1000 --mu 1 --eps 0.1 27 | ''' 28 | ############## Initialize ##################### 29 | parser = argparse.ArgumentParser(description='Adversarial Content Attack') 30 | parser.add_argument('--model', type=str, default='resnet50', help='model') 31 | parser.add_argument('--alpha', type=float, default=0.04, help='step size') 32 | parser.add_argument('--beta', type=float, default=0.1, help='mse factor') 33 | parser.add_argument('--eps', type=float, default=0.1, help='perturbation value') 34 | parser.add_argument('--steps', type=int, default=10, help='attack steps') 35 | parser.add_argument('--norm', type=int, default=2, help='loss norm') 36 | parser.add_argument('--lp', type=str, default='linf', help='perturbation norm') 37 | parser.add_argument('--start', default=0, type=int, help='img start') 38 | parser.add_argument('--end', default=1000, type=int, help='img end') 39 | parser.add_argument('--prefix', type=str, default='ACA-test', help='filename') 40 | parser.add_argument('--target', default=-1, type=int, help='target class, -1 is untargeted attack') 41 | parser.add_argument('--seed', default=0, type=int, help='random seed') 42 | parser.add_argument('--mu', default=1, type=float, help='momentum factor') 43 | 44 | 45 | args = parser.parse_args() 46 | print(args) 47 | 48 | torch.manual_seed(args.seed) 49 | torch.cuda.manual_seed(args.seed) 50 | np.random.seed(args.seed) 51 | torch.Generator().manual_seed(args.seed) 52 | 53 | 54 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 55 | print('==> Preparing Model..') 56 | image_size = (224, 224) 57 | if args.model == 'vit' or args.model == 'adv_resnet152_denoise': 58 | print('Using 0.5 Nor...') 59 | mean = [0.5, 0.5, 0.5] 60 | std = [0.5, 0.5, 0.5] 61 | elif args.model == 'mvit': 62 | mean = [0, 0, 0] 63 | std = [1, 1, 1] 64 | image_size = (320, 320) 65 | else: 66 | mean = [0.485, 0.456, 0.406] 67 | std = [0.229, 0.224, 0.225] 68 | 69 | 70 | mean = torch.Tensor(mean).cuda() 71 | std = torch.Tensor(std).cuda() 72 | 73 | net = get_model(args.model) 74 | if device == 'cuda': 75 | net.to(device) 76 | cudnn.benchmark = True 77 | net.eval() 78 | # # Apply SGM 79 | # if args.model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']: 80 | # register_hook_for_resnet(net, arch=args.model, gamma=0.5) 81 | # else: 82 | # raise ValueError('Current code only supports resnet/densenet. 
' 83 | # 'You can extend this code to other architectures.') 84 | net.cuda() 85 | 86 | class EmptyControl: 87 | def step_callback(self, x_t): 88 | return x_t 89 | 90 | def between_steps(self): 91 | return 92 | 93 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 94 | return attn 95 | 96 | 97 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 98 | if type(image_path) is str: 99 | image = np.array(Image.open(image_path))[:, :, :3] 100 | else: 101 | image = image_path 102 | h, w, c = image.shape 103 | left = min(left, w-1) 104 | right = min(right, w - left - 1) 105 | top = min(top, h - left - 1) 106 | bottom = min(bottom, h - top - 1) 107 | image = image[top:h-bottom, left:w-right] 108 | h, w, c = image.shape 109 | if h < w: 110 | offset = (w - h) // 2 111 | image = image[:, offset:offset + h] 112 | elif w < h: 113 | offset = (h - w) // 2 114 | image = image[offset:offset + w] 115 | image = np.array(Image.fromarray(image).resize((512, 512))) 116 | return image 117 | 118 | @torch.no_grad() 119 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 120 | if low_resource: 121 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 122 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 123 | else: 124 | latents_input = torch.cat([latents] * 2) 125 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 126 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 127 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 128 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 129 | latents = controller.step_callback(latents) 130 | return latents 131 | 132 | def limitation01(y): 133 | idx = (y > 1) 134 | y[idx] = (torch.tanh(1000*(y[idx]-1))+10000)/10001 135 | idx = (y < 0) 136 | y[idx] = (torch.tanh(1000*(y[idx])))/10000 137 | return y 138 | 139 | def norm_l2(Z): 140 | """Compute norms over all but the first dimension""" 141 | return Z.view(Z.shape[0], -1).norm(dim=1)[:,None,None,None] 142 | 143 | @torch.no_grad() 144 | def adversarial_latent_optimization( 145 | model, 146 | prompt: List[str], 147 | controller, 148 | num_inference_steps: int = 50, 149 | guidance_scale: Optional[float] = 7.5, 150 | generator: Optional[torch.Generator] = None, 151 | latent: Optional[torch.FloatTensor] = None, 152 | uncond_embeddings=None, 153 | start_time=50, 154 | label=None, 155 | raw_img=None 156 | ): 157 | batch_size = len(prompt) 158 | ptp_utils.register_attention_control(model, controller) 159 | height = width = 512 160 | 161 | text_input = model.tokenizer( 162 | prompt, 163 | padding="max_length", 164 | max_length=model.tokenizer.model_max_length, 165 | truncation=True, 166 | return_tensors="pt", 167 | ) 168 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 169 | max_length = text_input.input_ids.shape[-1] 170 | if uncond_embeddings is None: 171 | uncond_input = model.tokenizer( 172 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 173 | ) 174 | uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 175 | else: 176 | uncond_embeddings_ = None 177 | 178 | latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size) 179 | print("Latent", latent.shape, "Latents", latents.shape) 180 | model.scheduler.set_timesteps(num_inference_steps) 181 | 182 | 
best_latent = latents 183 | ori_latents = latents.clone().detach() 184 | adv_latents = latents.clone().detach() 185 | print(latents.max(), latents.min()) 186 | success = True 187 | momentum = 0 188 | for k in range(args.steps): 189 | latents = adv_latents 190 | for i, t in enumerate(model.scheduler.timesteps[-start_time:]): 191 | # print(i, t) 192 | if uncond_embeddings_ is None: 193 | context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]) 194 | else: 195 | context = torch.cat([uncond_embeddings_, text_embeddings]) 196 | latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False) 197 | 198 | image = None 199 | with torch.enable_grad(): 200 | latents_last = latents.detach().clone() 201 | latents_last.requires_grad = True 202 | latents_t = (1 / 0.18215 * latents_last) 203 | image = model.vae.decode(latents_t)['sample'] 204 | image = (image / 2 + 0.5) 205 | 206 | # print(4, image.max(), image.min()) 207 | image = limitation01(image) 208 | image_m = F.interpolate(image, image_size) 209 | 210 | # print(1, image_m.max(), image_m.min()) 211 | image_m = image_m - mean[None,:,None,None] 212 | image_m = image_m / std[None,:,None,None] 213 | outputs = net(image_m) 214 | _, predicted = outputs.max(1) 215 | 216 | if args.target == -1: 217 | if label != predicted: 218 | best_latent = adv_latents 219 | success = False 220 | else: 221 | if args.target == predicted: 222 | best_latent = adv_latents 223 | success = False 224 | 225 | 226 | if args.target == -1: 227 | loss_ce = torch.nn.CrossEntropyLoss()(outputs, torch.Tensor([label]).long().cuda()) 228 | else: 229 | loss_ce = -torch.nn.CrossEntropyLoss()(outputs, torch.Tensor([args.target]).long().cuda()) 230 | 231 | # print(3, image_m.max(), image_m.min(), raw_img.max(), raw_img.min()) 232 | loss_mse = args.beta * torch.norm(image_m-raw_img, p=args.norm).mean() 233 | loss = loss_ce - loss_mse 234 | loss.backward() 235 | 236 | print('*' * 50) 237 | print('Loss', loss.item(), 'Loss_ce', loss_ce.item(), 'Loss_mse', loss_mse.item()) 238 | print(k, 'Predicted:', label, predicted, loss.item()) 239 | # print('Grad:', latents_last.grad.min(), latents_last.grad.max()) 240 | 241 | l1_grad = latents_last.grad / torch.norm(latents_last.grad, p=1) 242 | # print('L1 Grad:', l1_grad.min(), l1_grad.max()) 243 | momentum = args.mu * momentum + l1_grad 244 | if args.lp == 'linf': 245 | adv_latents = adv_latents + torch.sign(momentum) * args.alpha 246 | noise = (adv_latents - ori_latents).clamp(-args.eps, args.eps) 247 | elif args.lp == 'l2': 248 | adv_latents = adv_latents + args.alpha * momentum.detach() / norm_l2(momentum.detach()) 249 | noise = (adv_latents - ori_latents) * args.eps / norm_l2(adv_latents - ori_latents).clamp(min=args.eps) 250 | adv_latents = ori_latents + noise 251 | latents = adv_latents.detach() 252 | 253 | if success: 254 | best_latent = latents 255 | 256 | # Return Best Attack 257 | latents = best_latent 258 | for i, t in enumerate(model.scheduler.timesteps[-start_time:]): 259 | # print(i, t) 260 | if uncond_embeddings_ is None: 261 | context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]) 262 | else: 263 | context = torch.cat([uncond_embeddings_, text_embeddings]) 264 | latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False) 265 | latents = (1 / 0.18215 * latents) 266 | image = model.vae.decode(latents)['sample'] 267 | image = (image / 2 + 0.5) 268 | # print(4, image.max(), 
image.min()) 269 | image = limitation01(image) 270 | image = F.interpolate(image, image_size) 271 | # print(2, image.max(), image.min()) 272 | 273 | image = image.clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy() 274 | image = (image * 255).astype(np.uint8) 275 | return image, best_latent, success 276 | 277 | 278 | if __name__ == '__main__': 279 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) 280 | MY_TOKEN = 'your_token' 281 | LOW_RESOURCE = False 282 | NUM_DDIM_STEPS = 50 283 | GUIDANCE_SCALE = 7.5 284 | MAX_NUM_WORDS = 77 285 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 286 | 287 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to(device) 288 | try: 289 | ldm_stable.disable_xformers_memory_efficient_attention() 290 | except AttributeError: 291 | print("Attribute disable_xformers_memory_efficient_attention() is missing") 292 | tokenizer = ldm_stable.tokenizer 293 | 294 | image_nums = 1000 295 | img_path = 'temp/1000/inversion' 296 | raw_img_path = 'third_party/Natural-Color-Fool/dataset/images' 297 | all_prompts = open('temp/1000/prompts.txt').readlines() 298 | all_latents = torch.load('temp/1000/all_latents.pth') 299 | all_uncons = torch.load('temp/1000/all_uncons.pth') 300 | all_labels = open('third_party/Natural-Color-Fool/dataset/labels.txt').readlines() 301 | 302 | img_list = os.listdir(img_path) 303 | img_list.sort() 304 | cnt = 0 305 | save_path = './temp/' + args.prefix + '-' + str(args.mu) + '-' + args.model + '-' + str(args.alpha) + '-' + str(args.beta) + '-' + str(args.norm) + '-' + str(args.steps) + '-Clip-' + str(args.eps) + '-' + args.lp + '-' + str(args.target) + '/' 306 | print("Save Path:", save_path) 307 | if not os.path.exists(save_path): 308 | os.mkdir(save_path) 309 | for i in trange(args.start, args.end): 310 | if not os.path.exists(os.path.join(save_path, img_list[i].split('.')[0]+'.png')): 311 | img_path = os.path.join(img_path, img_list[i]) 312 | idx = int(img_list[i].split('.')[0]) 313 | prompt = all_prompts[idx].strip() 314 | x_t = all_latents[idx].cuda() 315 | uncond_embeddings = all_uncons[idx].cuda() 316 | label = int(all_labels[idx].strip()) - 1 317 | print(idx, label) 318 | pil_image = Image.open(os.path.join(raw_img_path, str(idx+1)+'.png')).convert('RGB').resize(image_size) 319 | raw_img = (torch.tensor(np.array(pil_image), device=device).unsqueeze(0)/255.).permute(0, 3, 1, 2) 320 | raw_img = raw_img - mean[None,:,None,None] 321 | raw_img = raw_img / std[None,:,None,None] 322 | 323 | prompts = [prompt] 324 | controller = EmptyControl() 325 | image_inv, x_t, success = adversarial_latent_optimization(ldm_stable, prompts, controller, latent=x_t, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, generator=None, uncond_embeddings=uncond_embeddings, label=label, raw_img=raw_img) 326 | ptp_utils.view_images([image_inv[0]], prefix=os.path.join(save_path, img_list[i].split('.')[0])) 327 | cnt += success 328 | print("Acc: ", cnt, '/', (i-args.start+1)) 329 | 330 | else: 331 | print(os.path.join(save_path, img_list[i].split('.')[0]), " has existed!") 332 | 333 | 334 | -------------------------------------------------------------------------------- /checkpoints/download.sh: -------------------------------------------------------------------------------- 1 | wget -c https://download.pytorch.org/models/resnet50-0676ba61.pth 2 | wget -c 
https://download.pytorch.org/models/mobilenet_v2-b0353104.pth -------------------------------------------------------------------------------- /generate_prompts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from lavis.models import load_model_and_preprocess 5 | 6 | 7 | ''' 8 | TRANSFORMERS_OFFLINE=1 python3 generate_prompts.py 9 | ''' 10 | 11 | 12 | if __name__ == '__main__': 13 | image_nums = 1000 14 | save_path = 'temp/1000' 15 | all_prompts = [ "" for i in range(1000)] 16 | 17 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 18 | model, vis_processors, _ = load_model_and_preprocess(name="blip2_t5", model_type="pretrain_flant5xl", is_eval=True, device=device) 19 | # # load sample image 20 | img_path = 'third_party/Natural-Color-Fool/dataset/images' 21 | img_list = os.listdir(img_path) 22 | for i in range(image_nums): 23 | path = os.path.join(img_path, img_list[i]) 24 | idx = int(img_list[i].split('.')[0]) - 1 25 | raw_image = Image.open(path).convert("RGB") 26 | image = vis_processors["eval"](raw_image).unsqueeze(0).to(device) 27 | ans = model.generate({"image": image, "prompt": "Question: Please give a detailed description of the image. Answer:"}) 28 | 29 | print(path, idx, ans) 30 | all_prompts[idx] = ans[0] 31 | 32 | 33 | with open(os.path.join(save_path, 'prompts.txt'), 'w') as f: 34 | for i in range(len(all_prompts)): 35 | f.write(all_prompts[i]+ '\n') -------------------------------------------------------------------------------- /get_model.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import os 4 | import timm 5 | 6 | def get_model(model): 7 | home_path = './' 8 | if model == 'resnet50': 9 | net = torchvision.models.resnet50() 10 | net.load_state_dict(torch.load(os.path.join(home_path, 'checkpoints/resnet50-0676ba61.pth'))) 11 | elif model == 'mnv2': 12 | net = torchvision.models.mobilenet_v2() 13 | net.load_state_dict(torch.load(os.path.join(home_path, 'checkpoints/mobilenet_v2-b0353104.pth'))) 14 | return net -------------------------------------------------------------------------------- /image_latent_mapping.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple, List, Callable, Dict 2 | from tqdm import tqdm, trange 3 | import torch 4 | import os 5 | from diffusers import StableDiffusionPipeline, DDIMScheduler 6 | import torch.nn.functional as nnf 7 | import numpy as np 8 | import abc 9 | import ptp_utils 10 | import seq_aligner 11 | import shutil 12 | from torch.optim.adam import Adam 13 | from PIL import Image 14 | import time 15 | import cv2 16 | 17 | class LocalBlend: 18 | def get_mask(self, maps, alpha, use_pool): 19 | k = 1 20 | maps = (maps * alpha).sum(-1).mean(1) 21 | if use_pool: 22 | maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) 23 | mask = nnf.interpolate(maps, size=(x_t.shape[2:])) 24 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 25 | mask = mask.gt(self.th[1-int(use_pool)]) 26 | mask = mask[:1] + mask 27 | return mask 28 | 29 | def __call__(self, x_t, attention_store): 30 | self.counter += 1 31 | if self.counter > self.start_blend: 32 | 33 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 34 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps] 35 | maps = 
torch.cat(maps, dim=1) 36 | mask = self.get_mask(maps, self.alpha_layers, True) 37 | if self.substruct_layers is not None: 38 | maps_sub = ~self.get_mask(maps, self.substruct_layers, False) 39 | mask = mask * maps_sub 40 | mask = mask.float() 41 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 42 | return x_t 43 | 44 | def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, start_blend=0.2, th=(.3, .3)): 45 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 46 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 47 | if type(words_) is str: 48 | words_ = [words_] 49 | for word in words_: 50 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 51 | alpha_layers[i, :, :, :, :, ind] = 1 52 | 53 | if substruct_words is not None: 54 | substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 55 | for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): 56 | if type(words_) is str: 57 | words_ = [words_] 58 | for word in words_: 59 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 60 | substruct_layers[i, :, :, :, :, ind] = 1 61 | self.substruct_layers = substruct_layers.to(device) 62 | else: 63 | self.substruct_layers = None 64 | self.alpha_layers = alpha_layers.to(device) 65 | self.start_blend = int(start_blend * NUM_DDIM_STEPS) 66 | self.counter = 0 67 | self.th=th 68 | 69 | 70 | 71 | class EmptyControl: 72 | def step_callback(self, x_t): 73 | return x_t 74 | 75 | def between_steps(self): 76 | return 77 | 78 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 79 | return attn 80 | 81 | 82 | class AttentionControl(abc.ABC): 83 | 84 | def step_callback(self, x_t): 85 | return x_t 86 | 87 | def between_steps(self): 88 | return 89 | 90 | @property 91 | def num_uncond_att_layers(self): 92 | return self.num_att_layers if LOW_RESOURCE else 0 93 | 94 | @abc.abstractmethod 95 | def forward (self, attn, is_cross: bool, place_in_unet: str): 96 | raise NotImplementedError 97 | 98 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 99 | if self.cur_att_layer >= self.num_uncond_att_layers: 100 | if LOW_RESOURCE: 101 | attn = self.forward(attn, is_cross, place_in_unet) 102 | else: 103 | h = attn.shape[0] 104 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 105 | self.cur_att_layer += 1 106 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 107 | self.cur_att_layer = 0 108 | self.cur_step += 1 109 | self.between_steps() 110 | return attn 111 | 112 | def reset(self): 113 | self.cur_step = 0 114 | self.cur_att_layer = 0 115 | 116 | def __init__(self): 117 | self.cur_step = 0 118 | self.num_att_layers = -1 119 | self.cur_att_layer = 0 120 | 121 | class SpatialReplace(EmptyControl): 122 | 123 | def step_callback(self, x_t): 124 | if self.cur_step < self.stop_inject: 125 | b = x_t.shape[0] 126 | x_t = x_t[:1].expand(b, *x_t.shape[1:]) 127 | return x_t 128 | 129 | def __init__(self, stop_inject: float): 130 | super(SpatialReplace, self).__init__() 131 | self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS) 132 | 133 | 134 | class AttentionStore(AttentionControl): 135 | 136 | @staticmethod 137 | def get_empty_store(): 138 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 139 | "down_self": [], "mid_self": [], "up_self": []} 140 | 141 | def forward(self, attn, is_cross: bool, place_in_unet: str): 142 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 143 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 144 | 
self.step_store[key].append(attn) 145 | return attn 146 | 147 | def between_steps(self): 148 | if len(self.attention_store) == 0: 149 | self.attention_store = self.step_store 150 | else: 151 | for key in self.attention_store: 152 | for i in range(len(self.attention_store[key])): 153 | self.attention_store[key][i] += self.step_store[key][i] 154 | self.step_store = self.get_empty_store() 155 | 156 | def get_average_attention(self): 157 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 158 | return average_attention 159 | 160 | 161 | def reset(self): 162 | super(AttentionStore, self).reset() 163 | self.step_store = self.get_empty_store() 164 | self.attention_store = {} 165 | 166 | def __init__(self): 167 | super(AttentionStore, self).__init__() 168 | self.step_store = self.get_empty_store() 169 | self.attention_store = {} 170 | 171 | 172 | class AttentionControlEdit(AttentionStore, abc.ABC): 173 | 174 | def step_callback(self, x_t): 175 | if self.local_blend is not None: 176 | x_t = self.local_blend(x_t, self.attention_store) 177 | return x_t 178 | 179 | def replace_self_attention(self, attn_base, att_replace, place_in_unet): 180 | if att_replace.shape[2] <= 32 ** 2: 181 | attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 182 | return attn_base 183 | else: 184 | return att_replace 185 | 186 | @abc.abstractmethod 187 | def replace_cross_attention(self, attn_base, att_replace): 188 | raise NotImplementedError 189 | 190 | def forward(self, attn, is_cross: bool, place_in_unet: str): 191 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 192 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 193 | h = attn.shape[0] // (self.batch_size) 194 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 195 | attn_base, attn_repalce = attn[0], attn[1:] 196 | if is_cross: 197 | alpha_words = self.cross_replace_alpha[self.cur_step] 198 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce 199 | attn[1:] = attn_repalce_new 200 | else: 201 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet) 202 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 203 | return attn 204 | 205 | def __init__(self, prompts, num_steps: int, 206 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 207 | self_replace_steps: Union[float, Tuple[float, float]], 208 | local_blend: Optional[LocalBlend]): 209 | super(AttentionControlEdit, self).__init__() 210 | self.batch_size = len(prompts) 211 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device) 212 | if type(self_replace_steps) is float: 213 | self_replace_steps = 0, self_replace_steps 214 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 215 | self.local_blend = local_blend 216 | 217 | class AttentionReplace(AttentionControlEdit): 218 | 219 | def replace_cross_attention(self, attn_base, att_replace): 220 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 221 | 222 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 223 | local_blend: Optional[LocalBlend] = None): 224 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 225 | 
self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 226 | 227 | 228 | class AttentionRefine(AttentionControlEdit): 229 | 230 | def replace_cross_attention(self, attn_base, att_replace): 231 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 232 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 233 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 234 | return attn_replace 235 | 236 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 237 | local_blend: Optional[LocalBlend] = None): 238 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 239 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 240 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) 241 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 242 | 243 | 244 | class AttentionReweight(AttentionControlEdit): 245 | 246 | def replace_cross_attention(self, attn_base, att_replace): 247 | if self.prev_controller is not None: 248 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 249 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 250 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 251 | return attn_replace 252 | 253 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 254 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): 255 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 256 | self.equalizer = equalizer.to(device) 257 | self.prev_controller = controller 258 | 259 | 260 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], 261 | Tuple[float, ...]]): 262 | if type(word_select) is int or type(word_select) is str: 263 | word_select = (word_select,) 264 | equalizer = torch.ones(1, 77) 265 | 266 | for word, val in zip(word_select, values): 267 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 268 | equalizer[:, inds] = val 269 | return equalizer 270 | 271 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): 272 | out = [] 273 | attention_maps = attention_store.get_average_attention() 274 | num_pixels = res ** 2 275 | for location in from_where: 276 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 277 | if item.shape[1] == num_pixels: 278 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 279 | out.append(cross_maps) 280 | out = torch.cat(out, dim=0) 281 | out = out.sum(0) / out.shape[0] 282 | return out.cpu() 283 | 284 | 285 | def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit: 286 | if blend_words is None: 287 | lb = None 288 | else: 289 | lb = LocalBlend(prompts, blend_word) 290 | if is_replace_controller: 291 | controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb) 292 | else: 293 | controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, 
self_replace_steps=self_replace_steps, local_blend=lb) 294 | if equilizer_params is not None: 295 | eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"]) 296 | controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, 297 | self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller) 298 | return controller 299 | 300 | 301 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0): 302 | tokens = tokenizer.encode(prompts[select]) 303 | decoder = tokenizer.decode 304 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select) 305 | images = [] 306 | for i in range(len(tokens)): 307 | image = attention_maps[:, :, i] 308 | image = 255 * image / image.max() 309 | image = image.unsqueeze(-1).expand(*image.shape, 3) 310 | image = image.numpy().astype(np.uint8) 311 | image = np.array(Image.fromarray(image).resize((256, 256))) 312 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 313 | images.append(image) 314 | ptp_utils.view_images(np.stack(images, axis=0), prefix='cross_attention') 315 | 316 | 317 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], 318 | max_com=10, select: int = 0): 319 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) 320 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 321 | images = [] 322 | for i in range(max_com): 323 | image = vh[i].reshape(res, res) 324 | image = image - image.min() 325 | image = 255 * image / image.max() 326 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 327 | image = Image.fromarray(image).resize((256, 256)) 328 | image = np.array(image) 329 | images.append(image) 330 | ptp_utils.view_images(np.concatenate(images, axis=1), prefix='self_attention') 331 | 332 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 333 | if type(image_path) is str: 334 | image = np.array(Image.open(image_path))[:, :, :3] 335 | else: 336 | image = image_path 337 | h, w, c = image.shape 338 | left = min(left, w-1) 339 | right = min(right, w - left - 1) 340 | top = min(top, h - left - 1) 341 | bottom = min(bottom, h - top - 1) 342 | image = image[top:h-bottom, left:w-right] 343 | h, w, c = image.shape 344 | if h < w: 345 | offset = (w - h) // 2 346 | image = image[:, offset:offset + h] 347 | elif w < h: 348 | offset = (h - w) // 2 349 | image = image[offset:offset + w] 350 | image = np.array(Image.fromarray(image).resize((512, 512))) 351 | return image 352 | 353 | 354 | class NullInversion: 355 | 356 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 357 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 358 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 359 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 360 | beta_prod_t = 1 - alpha_prod_t 361 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 362 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output 363 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction 364 | return prev_sample 365 | 366 | def 
next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 367 | timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep 368 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 369 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 370 | beta_prod_t = 1 - alpha_prod_t 371 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 372 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 373 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 374 | return next_sample 375 | 376 | def get_noise_pred_single(self, latents, t, context): 377 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"] 378 | return noise_pred 379 | 380 | def get_noise_pred(self, latents, t, is_forward=True, context=None): 381 | latents_input = torch.cat([latents] * 2) 382 | if context is None: 383 | context = self.context 384 | guidance_scale = 1 if is_forward else GUIDANCE_SCALE 385 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 386 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 387 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 388 | if is_forward: 389 | latents = self.next_step(noise_pred, t, latents) 390 | else: 391 | latents = self.prev_step(noise_pred, t, latents) 392 | return latents 393 | 394 | @torch.no_grad() 395 | def latent2image(self, latents, return_type='np'): 396 | latents = 1 / 0.18215 * latents.detach() 397 | image = self.model.vae.decode(latents)['sample'] 398 | if return_type == 'np': 399 | image = (image / 2 + 0.5).clamp(0, 1) 400 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 401 | image = (image * 255).astype(np.uint8) 402 | return image 403 | 404 | @torch.no_grad() 405 | def image2latent(self, image): 406 | with torch.no_grad(): 407 | if type(image) is Image: 408 | image = np.array(image) 409 | if type(image) is torch.Tensor and image.dim() == 4: 410 | latents = image 411 | else: 412 | image = torch.from_numpy(image).float() / 127.5 - 1 413 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 414 | latents = self.model.vae.encode(image)['latent_dist'].mean 415 | latents = latents * 0.18215 416 | return latents 417 | 418 | @torch.no_grad() 419 | def init_prompt(self, prompt: str): 420 | uncond_input = self.model.tokenizer( 421 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, 422 | return_tensors="pt" 423 | ) 424 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 425 | text_input = self.model.tokenizer( 426 | [prompt], 427 | padding="max_length", 428 | max_length=self.model.tokenizer.model_max_length, 429 | truncation=True, 430 | return_tensors="pt", 431 | ) 432 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 433 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 434 | self.prompt = prompt 435 | 436 | @torch.no_grad() 437 | def ddim_loop(self, latent): 438 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 439 | all_latent = [latent] 440 | latent = latent.clone().detach() 441 | for i in range(NUM_DDIM_STEPS): 442 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] 443 | 
noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) 444 | latent = self.next_step(noise_pred, t, latent) 445 | all_latent.append(latent) 446 | return all_latent 447 | 448 | @property 449 | def scheduler(self): 450 | return self.model.scheduler 451 | 452 | @torch.no_grad() 453 | def ddim_inversion(self, image): 454 | latent = self.image2latent(image) 455 | image_rec = self.latent2image(latent) 456 | ddim_latents = self.ddim_loop(latent) 457 | return image_rec, ddim_latents 458 | 459 | def null_optimization(self, latents, num_inner_steps, epsilon): 460 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 461 | uncond_embeddings_list = [] 462 | latent_cur = latents[-1] 463 | bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS) 464 | for i in range(NUM_DDIM_STEPS): 465 | uncond_embeddings = uncond_embeddings.clone().detach() 466 | uncond_embeddings.requires_grad = True 467 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.)) 468 | latent_prev = latents[len(latents) - i - 2] 469 | t = self.model.scheduler.timesteps[i] 470 | with torch.no_grad(): 471 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) 472 | for j in range(num_inner_steps): 473 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) 474 | noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond) 475 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) 476 | loss = nnf.mse_loss(latents_prev_rec, latent_prev) 477 | optimizer.zero_grad() 478 | loss.backward() 479 | optimizer.step() 480 | loss_item = loss.item() 481 | bar.update() 482 | if loss_item < epsilon + i * 2e-5: 483 | break 484 | for j in range(j + 1, num_inner_steps): 485 | bar.update() 486 | uncond_embeddings_list.append(uncond_embeddings[:1].detach()) 487 | with torch.no_grad(): 488 | context = torch.cat([uncond_embeddings, cond_embeddings]) 489 | latent_cur = self.get_noise_pred(latent_cur, t, False, context) 490 | bar.close() 491 | return uncond_embeddings_list 492 | 493 | def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False): 494 | self.init_prompt(prompt) 495 | ptp_utils.register_attention_control(self.model, None) 496 | image_gt = load_512(image_path, *offsets) 497 | if verbose: 498 | print("DDIM inversion...") 499 | image_rec, ddim_latents = self.ddim_inversion(image_gt) 500 | if verbose: 501 | print("Null-text optimization...") 502 | uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon) 503 | return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings 504 | 505 | 506 | def __init__(self, model): 507 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, 508 | set_alpha_to_one=False) 509 | self.model = model 510 | self.tokenizer = self.model.tokenizer 511 | self.model.scheduler.set_timesteps(NUM_DDIM_STEPS) 512 | self.prompt = None 513 | self.context = None 514 | 515 | @torch.no_grad() 516 | def text2image_ldm_stable( 517 | model, 518 | prompt: List[str], 519 | controller, 520 | num_inference_steps: int = 50, 521 | guidance_scale: Optional[float] = 7.5, 522 | generator: Optional[torch.Generator] = None, 523 | latent: Optional[torch.FloatTensor] = None, 524 | uncond_embeddings=None, 525 | start_time=50, 526 | return_type='image' 527 | ): 528 | batch_size = len(prompt) 529 | ptp_utils.register_attention_control(model, controller) 530 | height = width = 512 531 | 532 
| text_input = model.tokenizer( 533 | prompt, 534 | padding="max_length", 535 | max_length=model.tokenizer.model_max_length, 536 | truncation=True, 537 | return_tensors="pt", 538 | ) 539 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 540 | max_length = text_input.input_ids.shape[-1] 541 | if uncond_embeddings is None: 542 | uncond_input = model.tokenizer( 543 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 544 | ) 545 | uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 546 | else: 547 | uncond_embeddings_ = None 548 | 549 | latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size) 550 | model.scheduler.set_timesteps(num_inference_steps) 551 | for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])): 552 | if uncond_embeddings_ is None: 553 | context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]) 554 | else: 555 | context = torch.cat([uncond_embeddings_, text_embeddings]) 556 | latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False) 557 | 558 | if return_type == 'image': 559 | image = ptp_utils.latent2image(model.vae, latents) 560 | else: 561 | image = latents 562 | return image, latent 563 | 564 | 565 | 566 | def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True, prefix='inversion'): 567 | if run_baseline: 568 | print("w.o. prompt-to-prompt") 569 | images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator) 570 | print("with prompt-to-prompt") 571 | images, x_t = text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, uncond_embeddings=uncond_embeddings) 572 | if verbose: 573 | ptp_utils.view_images(images, prefix=prefix) 574 | return images, x_t 575 | 576 | 577 | if __name__ == '__main__': 578 | # Load Stable Diffusion 579 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) 580 | MY_TOKEN = 'your token' 581 | LOW_RESOURCE = False 582 | NUM_DDIM_STEPS = 50 583 | GUIDANCE_SCALE = 7.5 584 | MAX_NUM_WORDS = 77 585 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 586 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=MY_TOKEN, scheduler=scheduler).to(device) 587 | 588 | try: 589 | ldm_stable.disable_xformers_memory_efficient_attention() 590 | except AttributeError: 591 | print("Attribute disable_xformers_memory_efficient_attention() is missing") 592 | tokenizer = ldm_stable.tokenizer 593 | null_inversion = NullInversion(ldm_stable) 594 | 595 | 596 | # Batch Images Load 597 | image_nums = 1000 598 | all_prompts = open('temp/1000/prompts.txt').readlines() 599 | all_latents = torch.zeros(image_nums, 4, 64, 64) 600 | all_uncons = torch.zeros(image_nums, NUM_DDIM_STEPS, 77, 768) 601 | 602 | img_filepath = 'third_party/Natural-Color-Fool/dataset/images' 603 | filepath_list = os.listdir(img_filepath) 604 | avg_ssim, avg_mse, avg_psnr = 0, 0, 0 605 | for i in trange(image_nums): 606 | img_path = os.path.join(img_filepath, filepath_list[i]) 607 | idx = int(filepath_list[i].split('.')[0]) - 1 608 | print(img_path, filepath_list[i], idx) 609 | raw_image = 
Image.open(img_path).convert("RGB") 610 | prompts = [all_prompts[idx].strip()] 611 | print(prompts) 612 | 613 | start = time.time() 614 | # Image Inversion 615 | (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(img_path, prompts[0], offsets=(0,0,0,0), verbose=True) 616 | print('Inversion Time:', time.time() - start) 617 | print(x_t.shape) 618 | print(len(uncond_embeddings), uncond_embeddings[0].shape) 619 | 620 | all_latents[idx] = x_t 621 | for k in range(NUM_DDIM_STEPS): 622 | all_uncons[idx][k] = uncond_embeddings[k] 623 | 624 | controller = AttentionStore() 625 | image_inv, x_t = run_and_display(prompts, controller, run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings, verbose=False) 626 | print("showing from left to right: the ground truth image, the vq-autoencoder reconstruction, the null-text inverted image") 627 | ptp_utils.view_images([image_gt, image_inv[0]], prefix='1000/pair/%d' % (idx)) 628 | ptp_utils.view_images([image_gt], prefix='1000/original/%d' % (idx)) 629 | ptp_utils.view_images([image_inv[0]], prefix='1000/inversion/%d' % (idx)) 630 | 631 | 632 | torch.save(all_latents, 'temp/1000/all_latents.pth') 633 | torch.save(all_uncons, 'temp/1000/all_uncons.pth') -------------------------------------------------------------------------------- /ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm.notebook import tqdm 22 | 23 | 24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 25 | h, w, c = image.shape 26 | offset = int(h * .2) 27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | img[:h] = image 30 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 31 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 32 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 33 | return img 34 | 35 | 36 | def view_images(images, num_rows=1, offset_ratio=0.02, prefix='test'): 37 | if type(images) is list: 38 | num_empty = len(images) % num_rows 39 | elif images.ndim == 4: 40 | num_empty = images.shape[0] % num_rows 41 | else: 42 | images = [images] 43 | num_empty = 0 44 | 45 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 46 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 47 | num_items = len(images) 48 | 49 | h, w, c = images[0].shape 50 | offset = int(h * offset_ratio) 51 | num_cols = num_items // num_rows 52 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 53 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 54 | for i in range(num_rows): 55 | for j in range(num_cols): 56 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 57 | i * num_cols + j] 58 | 59 | pil_img = Image.fromarray(image_) 60 | display(pil_img) 61 | pil_img.save(prefix + '.png') 62 | 63 | 64 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 65 | if low_resource: 66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 68 | else: 69 | # print("Latent: ", latents.shape, latents.requires_grad) 70 | latents_input = torch.cat([latents] * 2) 71 | # print("Latent_input: ", latents_input.shape, latents_input.requires_grad) 72 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 73 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 74 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 75 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 76 | latents = controller.step_callback(latents) 77 | return latents 78 | 79 | 80 | def latent2image(vae, latents): 81 | latents = 1 / 0.18215 * latents 82 | image = vae.decode(latents)['sample'] 83 | image = (image / 2 + 0.5).clamp(0, 1) 84 | image = image.cpu().permute(0, 2, 3, 1).numpy() 85 | image = (image * 255).astype(np.uint8) 86 | return image 87 | 88 | 89 | def init_latent(latent, model, height, width, generator, batch_size): 90 | if latent is None: 91 | latent = torch.randn( 92 | (1, model.unet.in_channels, height // 8, width // 8), 93 | generator=generator, 94 | ) 95 | print(latent.shape) 96 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 97 | print("init:",latents.shape) 98 | return latent, latents 99 | 100 | 101 | @torch.no_grad() 102 | def text2image_ldm( 103 | model, 104 | prompt: List[str], 105 | controller, 106 | num_inference_steps: int = 50, 107 | 
guidance_scale: Optional[float] = 7., 108 | generator: Optional[torch.Generator] = None, 109 | latent: Optional[torch.FloatTensor] = None, 110 | ): 111 | register_attention_control(model, controller) 112 | height = width = 256 113 | batch_size = len(prompt) 114 | 115 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 116 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 117 | 118 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 119 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 120 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 121 | context = torch.cat([uncond_embeddings, text_embeddings]) 122 | 123 | model.scheduler.set_timesteps(num_inference_steps) 124 | for t in tqdm(model.scheduler.timesteps): 125 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 126 | 127 | image = latent2image(model.vqvae, latents) 128 | 129 | return image, latent 130 | 131 | 132 | @torch.no_grad() 133 | def text2image_ldm_stable( 134 | model, 135 | prompt: List[str], 136 | controller, 137 | num_inference_steps: int = 50, 138 | guidance_scale: float = 7.5, 139 | generator: Optional[torch.Generator] = None, 140 | latent: Optional[torch.FloatTensor] = None, 141 | low_resource: bool = False, 142 | ): 143 | register_attention_control(model, controller) 144 | height = width = 512 145 | batch_size = len(prompt) 146 | text_input = model.tokenizer( 147 | prompt, 148 | padding="max_length", 149 | max_length=model.tokenizer.model_max_length, 150 | truncation=True, 151 | return_tensors="pt", 152 | ) 153 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 154 | max_length = text_input.input_ids.shape[-1] 155 | uncond_input = model.tokenizer( 156 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 157 | ) 158 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 159 | 160 | context = [uncond_embeddings, text_embeddings] 161 | if not low_resource: 162 | context = torch.cat(context) 163 | print("Context: ", context.shape, context.requires_grad) 164 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 165 | 166 | # set timesteps 167 | extra_set_kwargs = {"offset": 1} 168 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 169 | cnt = 0 170 | for t in tqdm(model.scheduler.timesteps): 171 | print('-' * 40) 172 | print(cnt) 173 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 174 | cnt += 1 175 | 176 | image = latent2image(model.vae, latents) 177 | 178 | return image, latent 179 | 180 | 181 | def register_attention_control(model, controller): 182 | def ca_forward(self, place_in_unet): 183 | to_out = self.to_out 184 | if type(to_out) is torch.nn.modules.container.ModuleList: 185 | to_out = self.to_out[0] 186 | else: 187 | to_out = self.to_out 188 | 189 | def forward(x, context=None, mask=None): 190 | batch_size, sequence_length, dim = x.shape 191 | h = self.heads 192 | q = self.to_q(x) 193 | is_cross = context is not None 194 | context = context if is_cross else x 195 | k = self.to_k(context) 196 | v = self.to_v(context) 197 | q = self.reshape_heads_to_batch_dim(q) 198 | k = self.reshape_heads_to_batch_dim(k) 199 | v = self.reshape_heads_to_batch_dim(v) 200 | 201 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * 
self.scale 202 | 203 | if mask is not None: 204 | mask = mask.reshape(batch_size, -1) 205 | max_neg_value = -torch.finfo(sim.dtype).max 206 | mask = mask[:, None, :].repeat(h, 1, 1) 207 | sim.masked_fill_(~mask, max_neg_value) 208 | 209 | # attention, what we cannot get enough of 210 | attn = sim.softmax(dim=-1) 211 | # print("attn:", attn.requires_grad, attn.grad_fn, attn.shape) 212 | attn = controller(attn, is_cross, place_in_unet) 213 | # print("control attn:", attn.requires_grad, attn.grad_fn, attn.shape) 214 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 215 | out = self.reshape_batch_dim_to_heads(out) 216 | return to_out(out) 217 | 218 | return forward 219 | 220 | class DummyController: 221 | 222 | def __call__(self, *args): 223 | return args[0] 224 | 225 | def __init__(self): 226 | self.num_att_layers = 0 227 | 228 | if controller is None: 229 | controller = DummyController() 230 | 231 | def register_recr(net_, count, place_in_unet): 232 | if net_.__class__.__name__ == 'CrossAttention': 233 | net_.forward = ca_forward(net_, place_in_unet) 234 | return count + 1 235 | elif hasattr(net_, 'children'): 236 | for net__ in net_.children(): 237 | count = register_recr(net__, count, place_in_unet) 238 | return count 239 | 240 | cross_att_count = 0 241 | sub_nets = model.unet.named_children() 242 | for net in sub_nets: 243 | if "down" in net[0]: 244 | cross_att_count = cross_att_count + register_recr(net[1], 0, "down") 245 | elif "up" in net[0]: 246 | cross_att_count = cross_att_count + register_recr(net[1], 0, "up") 247 | elif "mid" in net[0]: 248 | cross_att_count = cross_att_count + register_recr(net[1], 0, "mid") 249 | 250 | controller.num_att_layers = cross_att_count 251 | 252 | 253 | def get_word_inds(text: str, word_place: int, tokenizer): 254 | split_text = text.split(" ") 255 | if type(word_place) is str: 256 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 257 | elif type(word_place) is int: 258 | word_place = [word_place] 259 | out = [] 260 | if len(word_place) > 0: 261 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 262 | cur_len, ptr = 0, 0 263 | 264 | for i in range(len(words_encode)): 265 | cur_len = cur_len + len(words_encode[i]) 266 | if ptr in word_place: 267 | out.append(i + 1) 268 | if cur_len >= len(split_text[ptr]): 269 | ptr = ptr + 1 270 | cur_len = 0 271 | return np.array(out) 272 | 273 | 274 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 275 | word_inds: Optional[torch.Tensor]=None): 276 | if type(bounds) is float: 277 | bounds = 0, bounds 278 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 279 | if word_inds is None: 280 | word_inds = torch.arange(alpha.shape[2]) 281 | alpha[: start, prompt_ind, word_inds] = 0 282 | alpha[start: end, prompt_ind, word_inds] = 1 283 | alpha[end:, prompt_ind, word_inds] = 0 284 | return alpha 285 | 286 | 287 | def get_time_words_attention_alpha(prompts, num_steps, 288 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 289 | tokenizer, max_num_words=77): 290 | if type(cross_replace_steps) is not dict: 291 | cross_replace_steps = {"default_": cross_replace_steps} 292 | if "default_" not in cross_replace_steps: 293 | cross_replace_steps["default_"] = (0., 1.) 
294 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 295 | for i in range(len(prompts) - 1): 296 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 297 | i) 298 | for key, item in cross_replace_steps.items(): 299 | if key != "default_": 300 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 301 | for i, ind in enumerate(inds): 302 | if len(ind) > 0: 303 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 304 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 305 | return alpha_time_words 306 | -------------------------------------------------------------------------------- /seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j*gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i*gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i-1]) 88 | y_seq.append(y[j-1]) 89 | i = i-1 90 | j = j-1 91 | mapper_y_to_x.append((j, i)) 92 | 
elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j-1]) 95 | j = j-1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i-1]) 99 | y_seq.append('-') 100 | i = i-1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 | j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 | j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | 189 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 190 | x_seq = prompts[0] 
191 |     mappers = []
192 |     for i in range(1, len(prompts)):
193 |         mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
194 |         mappers.append(mapper)
195 |     return torch.stack(mappers)
196 | 
197 | 
--------------------------------------------------------------------------------
/temp/1000/temp.sh:
--------------------------------------------------------------------------------
1 | ...
--------------------------------------------------------------------------------
/test_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.modules import loss
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | import torch.backends.cudnn as cudnn
10 | import torchvision
11 | import torchvision.transforms as transforms
12 | from torch.autograd import Variable
13 | import os
14 | import argparse
15 | from tqdm import tqdm, trange
16 | import numpy as np
17 | from PIL import Image
18 | from get_model import get_model
19 | 
20 | 
21 | '''
22 | CUDA_VISIBLE_DEVICES=0 python3 test_image.py
23 | '''
24 | 
25 | ############## Initialize #####################
26 | parser = argparse.ArgumentParser(description='Evaluate model accuracy on generated images')
27 | parser.add_argument('--model', type=str, default='resnet50', help='model name')
28 | parser.add_argument('--img_path', type=str, default='', help='directory of images to evaluate')
29 | parser.add_argument('--seed', default=0, type=int, help='random seed')
30 | 
31 | args = parser.parse_args()
32 | print(args)
33 | 
34 | home_path = './'
35 | 
36 | torch.manual_seed(args.seed)
37 | torch.cuda.manual_seed(args.seed)
38 | np.random.seed(args.seed)
39 | 
40 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     torch.set_printoptions(precision=7)
45 |     img_root = args.img_path
46 |     print(img_root)
47 |     img_list = os.listdir(img_root)
48 |     img_list.sort()
49 |     print('Total Images', len(img_list))
50 | 
51 |     # models = ['mnv2', 'inception_v3', 'resnet50', 'densenet161', 'resnet152', 'ef_b7', 'mvit', 'vit', 'swint', 'pvtv2'] #]
52 |     models = [args.model]  # evaluate the model selected via --model
53 | 
54 |     for model in models:
55 |         if model == 'vit':
56 |             print('Using 0.5 Nor...')
57 |             mean = [0.5, 0.5, 0.5]
58 |             std = [0.5, 0.5, 0.5]
59 |         elif model == 'mvit' or model == 'vitb_adv' or model == 'covnext_l_adv':
60 |             mean = [0, 0, 0]
61 |             std = [1, 1, 1]
62 |         else:
63 |             mean = [0.485, 0.456, 0.406]
64 |             std = [0.229, 0.224, 0.225]
65 | 
66 |         mean = torch.Tensor(mean).cuda()
67 |         std = torch.Tensor(std).cuda()
68 | 
69 |         # Model
70 |         net = get_model(model)
71 | 
72 |         if device == 'cuda':
73 |             net.to(device)
74 |             cudnn.benchmark = True
75 |         net.eval()
76 |         net.cuda()
77 |         if model == 'inception_v3':
78 |             image_size = (299, 299)
79 |         elif model == 'mvit':
80 |             image_size = (320, 320)
81 |         else:
82 |             image_size = (224, 224)
83 |         # print(img_list)
84 |         cnt = 0
85 |         labels_f = open('third_party/Natural-Color-Fool/dataset/labels.txt').readlines()
86 |         acc = 0
87 | 
88 |         for (i, img_p) in enumerate(img_list):
89 |             pil_image = Image.open(os.path.join(img_root, img_p)).convert('RGB').resize(image_size)
90 |             img = (torch.tensor(np.array(pil_image), device=device).unsqueeze(0)/255.).permute(0, 3, 1, 2)
91 |             img = img - mean[None,:,None,None]
92 |             img = img / std[None,:,None,None]
93 |             out = net(img.cuda())
94 |             _, predicted = out.max(1)
95 |             idx = int(img_p.split('.')[0])
96 |             label = int(labels_f[idx]) - 1
97 |             if predicted[0] == label:
98 |                 acc += 1
99 |         print('-' * 60)
100 |         print(model, "ASR:", 100-acc/len(img_list)*100)
101 | 
102 | 
103 | 
--------------------------------------------------------------------------------
/test_quality.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import numpy as np
4 | from PIL import Image
5 | import random
6 | import torchvision.transforms as transforms
7 | import torchvision
8 | import torch.utils.data as data
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | from PIL import Image
12 | import os
13 | from tqdm import trange
14 | import numpy
15 | import numpy as np
16 | import math
17 | import pyiqa
18 | 
19 | 
20 | parser = argparse.ArgumentParser(description='Test Image Quality!')
21 | parser.add_argument('--img_path', type=str, default='./', help='directory of images to evaluate')
22 | parser.add_argument('--metric', type=str, default='musiq-koniq', help='pyiqa metric name')
23 | parser.add_argument('--seed', type=int, default=42, help='random seed')
24 | 
25 | 
26 | args = parser.parse_args()
27 | print(args)
28 | 
29 | torch.manual_seed(args.seed)
30 | torch.cuda.manual_seed(args.seed)
31 | np.random.seed(args.seed)
32 | random.seed(args.seed)
33 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
34 | 
35 | 
36 | def img_loader(path):
37 |     try:
38 |         with open(path, 'rb') as f:
39 |             img = Image.open(path).convert('RGB')
40 |             return img
41 |     except IOError:
42 |         print('Cannot load image ' + path)
43 | 
44 | if __name__ == '__main__':
45 |     img1_root = '/temp/1000/original'  # reference images (not used below)
46 |     img_roots = [
47 |         args.img_path
48 |     ]
49 | 
50 |     name = args.metric
51 |     iqa_metric = pyiqa.create_metric(name, device=device)
52 |     print(iqa_metric.lower_better)
53 |     fid_metric = pyiqa.create_metric('fid', device=device)  # FID metric (not used below)
54 | 
55 |     # print(f_list)
56 |     for img2_root in img_roots:
57 |         f_list = os.listdir(img2_root)
58 |         nnima = 0
59 |         for i in trange(len(f_list)):
60 |             img2 = torch.Tensor(np.array(img_loader(os.path.join(img2_root, f_list[i])).resize((224, 224)))).unsqueeze(0).to(device).permute(0, 3, 1, 2)/255.0
61 |             score_nr = iqa_metric(img2)
62 |             nnima += score_nr
63 |         print('*' * 60)
64 |         print(img2_root)
65 |         print(name, nnima/len(f_list))
66 | 
67 | 
68 | 
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/third_party/download.sh:
--------------------------------------------------------------------------------
1 | git clone https://github.com/VL-Group/Natural-Color-Fool
--------------------------------------------------------------------------------
/utils_sgm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | def backward_hook(gamma):
7 |     # implement SGM by scaling the gradient flowing back through ReLU
8 |     def _backward_hook(module, grad_in, grad_out):
9 |         if isinstance(module, nn.ReLU):
10 |             return (gamma * grad_in[0],)
11 |     return _backward_hook
12 | 
13 | 
14 | def backward_hook_norm(module, grad_in, grad_out):
15 |     # normalize the gradient to avoid gradient explosion or vanishing
16 |     std = torch.std(grad_in[0])
17 |     return (grad_in[0] / std,)
18 | 
19 | 
20 | def register_hook_for_resnet(model, arch, gamma):
21 |     # There is only 1 ReLU in Conv module of ResNet-18/34
22 |     # and 2 ReLU in Conv module ResNet-50/101/152
23 |     if arch in ['resnet50', 'resnet101', 'resnet152']:
24 |         gamma = np.power(gamma, 0.5)
25 |     backward_hook_sgm = backward_hook(gamma)
26 | 
27 |     for name, module in model.named_modules():
28 |         if 'relu' in name and not '0.relu' in name:
29 |             module.register_backward_hook(backward_hook_sgm)
30 | 
31 |         # e.g., 1.layer1.1, 1.layer4.2, ...
32 |         # if len(name.split('.')) == 3:
33 |         if len(name.split('.')) >= 2 and 'layer' in name.split('.')[-2]:
34 |             module.register_backward_hook(backward_hook_norm)
35 | 
36 | 
37 | def register_hook_for_densenet(model, arch, gamma):
38 |     # There are 2 ReLU in Conv module of DenseNet-121/169/201.
39 |     gamma = np.power(gamma, 0.5)
40 |     backward_hook_sgm = backward_hook(gamma)
41 |     for name, module in model.named_modules():
42 |         if 'relu' in name and not 'transition' in name:
43 |             module.register_backward_hook(backward_hook_sgm)
44 | 
--------------------------------------------------------------------------------
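The SGM hooks in `utils_sgm.py` are meant to be registered on the surrogate classifier before gradients are back-propagated to the input. A minimal sketch of that wiring, assuming a plain torchvision ResNet-50 as a stand-in for the repository's `get_model` output; the model choice, `gamma=0.5`, and the random input are illustrative assumptions, not taken from the repository:

```python
import torch
import torchvision.models as models

from utils_sgm import register_hook_for_resnet

# Stand-in surrogate; the repository normally builds its classifiers via get_model().
net = models.resnet50(weights=None).eval()

# For ResNet-50/101/152 the function internally applies gamma ** 0.5 to each ReLU hook.
register_hook_for_resnet(net, arch='resnet50', gamma=0.5)

x = torch.randn(1, 3, 224, 224, requires_grad=True)
score = net(x)[0, 0]      # any scalar objective works for this demonstration
score.backward()          # ReLU gradients are rescaled, emphasizing skip connections (SGM)
print(x.grad.abs().mean())
```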
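Similarly, a small sketch of how the alignment utilities in `seq_aligner.py` can be driven; the CLIP tokenizer checkpoint and the prompts are illustrative assumptions (any tokenizer exposing `encode` works the same way):

```python
from transformers import CLIPTokenizer

import seq_aligner

# Tokenizer of Stable Diffusion v1.x; assumed here purely for illustration.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompts = [
    "a photo of a cat on a sofa",        # source prompt
    "a photo of a small cat on a sofa",  # refined prompt (one inserted word)
]

# One (mapper, alpha) pair per edited prompt, each padded to max_len tokens.
mappers, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer, max_len=77)
print(mappers.shape, alphas.shape)  # torch.Size([1, 77]) torch.Size([1, 77])
# alphas is 0 at token positions that exist only in the refined prompt ("small").
```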