├── README.md
├── config.py
├── docs
│   ├── bbox_as_condition.json
│   ├── boxdiff.gif
│   ├── boxdiff.png
│   ├── example.png
│   └── visorgpt.gif
├── eval.py
├── pipeline
│   ├── gligen_pipeline_boxdiff.py
│   └── sd_pipeline_boxdiff.py
├── requirements.txt
├── run_gligen_boxdiff.py
├── run_sd_boxdiff.py
└── utils
    ├── __init__.py
    ├── drawer.py
    ├── gaussian_smoothing.py
    ├── ptp_utils.py
    └── vis_utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | BoxDiff 🎨 (ICCV 2023)
3 | BoxDiff: Text-to-Image Synthesis with Training-Free Box-Constrained Diffusion
4 |
5 | [Jinheng Xie](https://sierkinhane.github.io/)1 Yuexiang Li2 Yawen Huang2 Haozhe Liu2,3 Wentian Zhang2 Yefeng Zheng2 [Mike Zheng Shou](https://scholar.google.com/citations?hl=zh-CN&user=h1-3lSoAAAAJ&view_op=list_works&sortby=pubdate)1
6 |
7 | 1 National University of Singapore 2 Tencent Jarvis Lab 3 KAUST
8 |
9 | [arXiv](https://arxiv.org/abs/2307.10816)
10 |
11 |
12 |
13 |
14 |
15 | ### BoxDiff-SD-XL
16 | [A BoxDiff implementation based on SD-XL](https://github.com/Cominclip/BoxDiff-XL)
17 |
18 | ### Integration in diffusers
19 | Thanks to [@zjysteven](https://github.com/zjysteven) for his efforts. The example below uses `stable-diffusion-2-1-base`.
20 | ```
21 | import torch
22 | from PIL import Image, ImageDraw
23 | from copy import deepcopy
24 |
25 | from examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline
26 |
27 | def draw_box_with_text(img, boxes, names):
28 | colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
29 | img_new = deepcopy(img)
30 | draw = ImageDraw.Draw(img_new)
31 |
32 | W, H = img.size
33 | for bid, box in enumerate(boxes):
34 | draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4)
35 | draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)])
36 | return img_new
37 |
38 | pipe = StableDiffusionBoxDiffPipeline.from_pretrained(
39 | "stabilityai/stable-diffusion-2-1-base",
40 | torch_dtype=torch.float16,
41 | )
42 | pipe.to("cuda")
43 |
44 | # example 1
45 | prompt = "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed"
46 | phrases = [
47 | "aurora",
48 | "reindeer",
49 | "meadow",
50 | "lake",
51 | "mountain"
52 | ]
53 | boxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]]
54 |
55 | # example 2
56 | # prompt = "A rabbit wearing sunglasses looks very proud"
57 | # phrases = ["rabbit", "sunglasses"]
58 | # boxes = [[67,87,366,512], [66,130,364,262]]
59 |
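# note: the boxes above are given in 512x512 pixel coordinates; the line below
# normalizes them to the [0, 1] xyxy format expected by the pipeline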
60 | boxes = [[x / 512 for x in box] for box in boxes]
61 |
62 | images = pipe(
63 | prompt,
64 | boxdiff_phrases=phrases,
65 | boxdiff_boxes=boxes,
66 | boxdiff_kwargs={
67 | "attention_res": 16,
68 | "normalize_eot": True
69 | },
70 | num_inference_steps=50,
71 | guidance_scale=7.5,
72 | generator=torch.manual_seed(42),
73 | safety_checker=None
74 | ).images
75 |
76 | draw_box_with_text(images[0], boxes, phrases).save("output.png")
77 | ```
78 |
79 | ### Setup
80 | Note that we have only tested the code with PyTorch==1.12.0. You can build the environment via `pip` as follows:
81 | ```
82 | pip3 install -r requirements.txt
83 | ```
84 | To apply BoxDiff to the GLIGEN pipeline, please install the GLIGEN fork of diffusers as follows:
85 | ```
86 | git clone git@github.com:gligen/diffusers.git
87 | cd diffusers && pip3 install -e .
88 | ```
89 |
90 | ### Usage
91 | To add spatial control to the Stable Diffusion model, simply use `run_sd_boxdiff.py`. For example:
92 | ```
93 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,21,22,23,24,25,26,27,28,29,30] --token_indices [3,12,21,30,46] --bbox [[1,3,512,202],[75,344,421,495],[1,327,508,507],[2,217,507,341],[1,135,509,242]] --refine False
94 | ```
95 | Another example:
96 | ```
97 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]]
98 | ```
99 | Note that the token indices are the indices of the words you want to control in the text prompt, and each token index has one corresponding conditioning box. `P` and `L` are hyper-parameters of the proposed constraints.
100 |
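If you are unsure which indices to pass to `--token_indices`, you can print the tokenization of your prompt, similar to what `get_indices_to_alter` in `eval.py` does interactively. A minimal sketch, assuming the standard CLIP tokenizer used by Stable Diffusion v1.x (`openai/clip-vit-large-patch14`):
```
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "A rabbit wearing sunglasses looks very proud"
# print every token with its index; for this prompt, index 2 is "rabbit" and
# index 4 is "sunglasses", matching --token_indices [2,4] in the example above
for idx, token_id in enumerate(tokenizer(prompt)["input_ids"]):
    print(idx, tokenizer.decode([token_id]))
```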
101 | When `--bbox` is not specified, an interface is provided for drawing bounding boxes as conditions.
102 | ```
103 | CUDA_VISIBLE_DEVICES=0 python3 run_sd_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4]
104 | ```
105 |
106 | To add spatial control to the GLIGEN model, simply use `run_gligen_boxdiff.py`. For example:
107 | ```
108 | CUDA_VISIBLE_DEVICES=0 python3 run_gligen_boxdiff.py --prompt "A rabbit wearing sunglasses looks very proud" --gligen_phrases ["a rabbit","sunglasses"] --P 0.2 --L 1 --seeds [1,2,3,4,5,6,7,8,9] --token_indices [2,4] --bbox [[67,87,366,512],[66,130,364,262]] --refine False
109 | ```
110 |
111 | The directory structure of the synthesized results is as follows:
112 | ```
113 | outputs/
114 | |-- text prompt/
115 | | |-- 0.png
116 | | |-- 0_canvas.png
117 | | |-- 1.png
118 | | |-- 1_canvas.png
119 | | |-- ...
120 | ```
121 | 
122 |
123 | ### Customize Your Layout
124 | [VisorGPT](https://github.com/Sierkinhane/VisorGPT) can customize layouts as spatial conditions for image synthesis using BoxDiff.
125 |
126 | ### Citation
127 | ```
128 | @InProceedings{Xie_2023_ICCV,
129 | author = {Xie, Jinheng and Li, Yuexiang and Huang, Yawen and Liu, Haozhe and Zhang, Wentian and Zheng, Yefeng and Shou, Mike Zheng},
130 | title = {BoxDiff: Text-to-Image Synthesis with Training-Free Box-Constrained Diffusion},
131 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
132 | year = {2023},
133 | pages = {7452-7461}
134 | }
135 | ```
136 |
137 | Acknowledgment: the code builds heavily on [diffusers](https://github.com/huggingface/diffusers), [prompt-to-prompt](https://github.com/google/prompt-to-prompt), and the repositories of [yuval-alaluf](https://github.com/yuval-alaluf).
138 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import Dict, List
4 |
5 |
6 | @dataclass
7 | class RunConfig:
8 | # Guiding text prompt
9 | prompt: str
10 | # Whether to use Stable Diffusion v2.1
11 | sd_2_1: bool = False
12 | # Which token indices to alter with attend-and-excite
13 | token_indices: List[int] = None
14 | # Which random seeds to use when generating
15 | seeds: List[int] = field(default_factory=lambda: [42])
16 | # Path to save all outputs to
17 | output_path: Path = Path('./outputs')
18 | # Number of denoising steps
19 | n_inference_steps: int = 50
20 | # Text guidance scale
21 | guidance_scale: float = 7.5
22 | # Number of denoising steps to apply attend-and-excite
23 | max_iter_to_alter: int = 25
24 | # Resolution of UNet to compute attention maps over
25 | attention_res: int = 16
26 | # Whether to run standard SD or attend-and-excite
27 | run_standard_sd: bool = False
28 | # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
29 | thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
30 | # Scale factor for updating the denoised latent z_t
31 | scale_factor: int = 20
32 | # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
33 | scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
34 | # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
35 | smooth_attentions: bool = True
36 | # Standard deviation for the Gaussian smoothing
37 | sigma: float = 0.5
38 | # Kernel size for the Gaussian smoothing
39 | kernel_size: int = 3
40 | # Whether to save cross attention maps for the final results
41 | save_cross_attention_maps: bool = False
42 |
43 | # BoxDiff
44 | bbox: List[list] = field(default_factory=lambda: [[], []])
45 | color: List[str] = field(default_factory=lambda: ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black'])
46 | P: float = 0.2
47 | # number of pixels around the corner to be selected
48 | L: int = 1
49 | refine: bool = True
50 | gligen_phrases: List[str] = field(default_factory=lambda: ['', ''])
51 | n_splits: int = 4
52 | which_one: int = 1
53 | eval_output_path: Path = Path('./outputs/eval')
54 |
55 |
56 | def __post_init__(self):
57 | self.output_path.mkdir(exist_ok=True, parents=True)
58 |
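# Usage sketch (illustrative, not part of the original file): eval.py (and, presumably,
# the run scripts) builds a RunConfig from the command line via pyrallis, e.g.
#
#     @pyrallis.wrap()
#     def main(config: RunConfig):
#         ...
#
# so CLI flags such as --P, --L, --bbox and --token_indices map directly to the
# fields defined above.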
--------------------------------------------------------------------------------
/docs/boxdiff.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/boxdiff.gif
--------------------------------------------------------------------------------
/docs/boxdiff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/boxdiff.png
--------------------------------------------------------------------------------
/docs/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/example.png
--------------------------------------------------------------------------------
/docs/visorgpt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/docs/visorgpt.gif
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import pprint
3 | from typing import List
4 |
5 | import pyrallis
6 | import torch
7 | from PIL import Image
8 | from config import RunConfig
9 | from pipeline.gligen_pipeline_boxdiff import BoxDiffPipeline
10 | from utils import ptp_utils, vis_utils
11 | from utils.ptp_utils import AttentionStore
12 |
13 | import numpy as np
14 | from utils.drawer import draw_rectangle, DashedImageDraw
15 |
16 | import warnings
17 | import json, os
18 | from tqdm import tqdm
19 | warnings.filterwarnings("ignore", category=UserWarning)
20 |
21 |
22 | def load_model():
23 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
24 | stable_diffusion_version = "gligen/diffusers-generation-text-box"
25 | # If you cannot access the huggingface on your server, you can use the local prepared one.
26 | # stable_diffusion_version = "../../packages/diffusers/gligen_ckpts/diffusers-generation-text-box"
27 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version, revision="fp16", torch_dtype=torch.float16).to(device)
28 |
29 | return stable
30 |
31 |
32 | def get_indices_to_alter(stable, prompt: str) -> List[int]:
33 | token_idx_to_word = {idx: stable.tokenizer.decode(t)
34 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
35 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1}
36 | pprint.pprint(token_idx_to_word)
37 |     token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to "
38 | "alter (e.g., 2,5): ")
39 | token_indices = [int(i) for i in token_indices.split(",")]
40 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
41 | return token_indices
42 |
43 |
44 | def run_on_prompt(prompt: List[str],
45 | model: BoxDiffPipeline,
46 | controller: AttentionStore,
47 | token_indices: List[int],
48 | seed: torch.Generator,
49 | config: RunConfig) -> Image.Image:
50 | if controller is not None:
51 | ptp_utils.register_attention_control(model, controller)
52 |
53 | gligen_boxes = []
54 | for i in range(len(config.bbox)):
55 | x1, y1, x2, y2 = config.bbox[i]
56 | gligen_boxes.append([x1/512, y1/512, x2/512, y2/512])
57 |
58 | outputs = model(prompt=prompt,
59 | attention_store=controller,
60 | indices_to_alter=token_indices,
61 | attention_res=config.attention_res,
62 | guidance_scale=config.guidance_scale,
63 | gligen_phrases=config.gligen_phrases,
64 | gligen_boxes=gligen_boxes,
65 | gligen_scheduled_sampling_beta=0.3,
66 | generator=seed,
67 | num_inference_steps=config.n_inference_steps,
68 | max_iter_to_alter=config.max_iter_to_alter,
69 | run_standard_sd=config.run_standard_sd,
70 | thresholds=config.thresholds,
71 | scale_factor=config.scale_factor,
72 | scale_range=config.scale_range,
73 | smooth_attentions=config.smooth_attentions,
74 | sigma=config.sigma,
75 | kernel_size=config.kernel_size,
76 | sd_2_1=config.sd_2_1,
77 | bbox=config.bbox,
78 | config=config)
79 | image = outputs.images[0]
80 | return image
81 |
82 |
83 | @pyrallis.wrap()
84 | def main(config: RunConfig):
85 | stable = load_model()
86 |
87 | # read bbox from the pre-prepared .json file
88 | with open('docs/bbox_as_condition.json', 'r', encoding='utf8') as fp:
89 | bbox_json = json.load(fp)
90 |
91 | idx = np.arange(len(bbox_json))
92 | split_idx = list(np.array_split(idx, config.n_splits)[config.which_one - 1])
93 |
94 | for bidx in tqdm(split_idx):
95 |
96 | filename = bbox_json[bidx]['filename']
97 | sub_dir = filename.split('/')[-2]
98 | img_name = filename.split('/')[-1].split('.')[0]
99 |
100 | objects = bbox_json[bidx]['objects']
101 | cls_name = list(objects.keys())
102 | config.bbox = list(objects.values())
103 | # import ipdb
104 | # ipdb.set_trace()
105 | text_prompt = ''
106 | gligen_phrases = []
107 | token_indices = []
108 | for nidx, n in enumerate(cls_name):
109 | text_prompt += f'a {n} and '
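            # For a prompt of the form "a X and a Y and ...", the class nouns land at
            # token indices 2, 5, 8, ... (<bos>=0, "a"=1, X=2, "and"=3, "a"=4, Y=5, ...),
            # assuming every class name maps to a single CLIP token.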
110 | if nidx == 0:
111 | token_indices.append(2)
112 | else:
113 | token_indices.append(5 + (nidx - 1) * 3)
114 |             gligen_phrases.append(f'a {n}')
115 | config.prompt = text_prompt[:-5]
116 | config.gligen_phrases = gligen_phrases
117 |
118 | for seed in config.seeds:
119 | print(f"Current seed is : {seed}")
120 | g = torch.Generator('cuda').manual_seed(seed)
121 | controller = AttentionStore()
122 | controller.num_uncond_att_layers = -16
123 | image = run_on_prompt(prompt=config.prompt,
124 | model=stable,
125 | controller=controller,
126 | token_indices=token_indices,
127 | seed=g,
128 | config=config)
129 |
130 | prompt_output_path = config.eval_output_path / sub_dir
131 | prompt_output_path.mkdir(exist_ok=True, parents=True)
132 | if os.path.isfile(prompt_output_path / f'{img_name}_{seed}.png'):
133 | continue
134 |
135 | image.save(prompt_output_path / f'{img_name}_{seed}.png')
136 |
137 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220)
138 | draw = DashedImageDraw(canvas)
139 |
140 | for i in range(len(config.bbox)):
141 | x1, y1, x2, y2 = config.bbox[i]
142 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5)
143 | canvas.save(prompt_output_path / f'{img_name}_{seed}_canvas.png')
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
--------------------------------------------------------------------------------
/pipeline/gligen_pipeline_boxdiff.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10 |
11 | from diffusers.pipelines.stable_diffusion import StableDiffusionGLIGENPipeline
12 |
13 | from utils.gaussian_smoothing import GaussianSmoothing
14 | from utils.ptp_utils import AttentionStore, aggregate_attention
15 | import PIL
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 | class BoxDiffPipeline(StableDiffusionGLIGENPipeline):
20 | r"""
21 | Pipeline for text-to-image generation using Stable Diffusion.
22 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
23 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
24 | Args:
25 | vae ([`AutoencoderKL`]):
26 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
27 | text_encoder ([`CLIPTextModel`]):
28 | Frozen text-encoder. Stable Diffusion uses the text portion of
29 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
30 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
31 | tokenizer (`CLIPTokenizer`):
32 | Tokenizer of class
33 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
34 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
35 | scheduler ([`SchedulerMixin`]):
36 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
37 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
38 | safety_checker ([`StableDiffusionSafetyChecker`]):
39 | Classification module that estimates whether generated images could be considered offensive or harmful.
40 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
41 | feature_extractor ([`CLIPFeatureExtractor`]):
42 | Model that extracts features from generated images to be used as inputs for the `safety_checker`.
43 | """
44 | _optional_components = ["safety_checker", "feature_extractor"]
45 |
46 | def _encode_prompt(
47 | self,
48 | prompt,
49 | device,
50 | num_images_per_prompt,
51 | do_classifier_free_guidance,
52 | negative_prompt=None,
53 | prompt_embeds: Optional[torch.FloatTensor] = None,
54 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
55 | ):
56 | r"""
57 | Encodes the prompt into text encoder hidden states.
58 | Args:
59 | prompt (`str` or `List[str]`, *optional*):
60 | prompt to be encoded
61 | device: (`torch.device`):
62 | torch device
63 | num_images_per_prompt (`int`):
64 | number of images that should be generated per prompt
65 | do_classifier_free_guidance (`bool`):
66 | whether to use classifier free guidance or not
 67 |             negative_prompt (`str` or `List[str]`, *optional*):
 68 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
 69 |                 `negative_prompt_embeds` instead.
70 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
71 | prompt_embeds (`torch.FloatTensor`, *optional*):
72 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
73 | provided, text embeddings will be generated from `prompt` input argument.
74 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
75 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
76 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
77 | argument.
78 | """
79 | if prompt is not None and isinstance(prompt, str):
80 | batch_size = 1
81 | elif prompt is not None and isinstance(prompt, list):
82 | batch_size = len(prompt)
83 | else:
84 | batch_size = prompt_embeds.shape[0]
85 |
86 | if prompt_embeds is None:
87 | text_inputs = self.tokenizer(
88 | prompt,
89 | padding="max_length",
90 | max_length=self.tokenizer.model_max_length,
91 | truncation=True,
92 | return_tensors="pt",
93 | )
94 | text_input_ids = text_inputs.input_ids
95 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
96 |
97 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
98 | text_input_ids, untruncated_ids
99 | ):
100 | removed_text = self.tokenizer.batch_decode(
101 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
102 | )
103 | logger.warning(
104 | "The following part of your input was truncated because CLIP can only handle sequences up to"
105 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
106 | )
107 |
108 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
109 | attention_mask = text_inputs.attention_mask.to(device)
110 | else:
111 | attention_mask = None
112 |
113 | prompt_embeds = self.text_encoder(
114 | text_input_ids.to(device),
115 | attention_mask=attention_mask,
116 | )
117 | prompt_embeds = prompt_embeds[0]
118 |
119 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
120 |
121 | bs_embed, seq_len, _ = prompt_embeds.shape
122 | # duplicate text embeddings for each generation per prompt, using mps friendly method
123 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
124 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
125 |
126 | # get unconditional embeddings for classifier free guidance
127 | if do_classifier_free_guidance and negative_prompt_embeds is None:
128 | uncond_tokens: List[str]
129 | if negative_prompt is None:
130 | uncond_tokens = [""] * batch_size
131 | elif type(prompt) is not type(negative_prompt):
132 | raise TypeError(
133 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
134 | f" {type(prompt)}."
135 | )
136 | elif isinstance(negative_prompt, str):
137 | uncond_tokens = [negative_prompt]
138 | elif batch_size != len(negative_prompt):
139 | raise ValueError(
140 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
141 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
142 | " the batch size of `prompt`."
143 | )
144 | else:
145 | uncond_tokens = negative_prompt
146 |
147 | max_length = prompt_embeds.shape[1]
148 | uncond_input = self.tokenizer(
149 | uncond_tokens,
150 | padding="max_length",
151 | max_length=max_length,
152 | truncation=True,
153 | return_tensors="pt",
154 | )
155 |
156 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
157 | attention_mask = uncond_input.attention_mask.to(device)
158 | else:
159 | attention_mask = None
160 |
161 | negative_prompt_embeds = self.text_encoder(
162 | uncond_input.input_ids.to(device),
163 | attention_mask=attention_mask,
164 | )
165 | negative_prompt_embeds = negative_prompt_embeds[0]
166 |
167 | if do_classifier_free_guidance:
168 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
169 | seq_len = negative_prompt_embeds.shape[1]
170 |
171 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
172 |
173 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
174 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
175 |
176 | # For classifier free guidance, we need to do two forward passes.
177 | # Here we concatenate the unconditional and text embeddings into a single batch
178 | # to avoid doing two forward passes
179 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
180 |
181 | return text_inputs, prompt_embeds
182 |
183 | def _compute_max_attention_per_index(self,
184 | attention_maps: torch.Tensor,
185 | indices_to_alter: List[int],
186 | smooth_attentions: bool = False,
187 | sigma: float = 0.5,
188 | kernel_size: int = 3,
189 | normalize_eot: bool = False,
190 | bbox: List[int] = None,
191 | config=None,
192 | ) -> List[torch.Tensor]:
193 | """ Computes the maximum attention value for each of the tokens we wish to alter. """
194 | last_idx = -1
195 | if normalize_eot:
196 | prompt = self.prompt
197 | if isinstance(self.prompt, list):
198 | prompt = self.prompt[0]
199 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1
200 | attention_for_text = attention_maps[:, :, 1:last_idx]
201 | attention_for_text *= 100
202 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)
203 |
204 | # Shift indices since we removed the first token
205 | indices_to_alter = [index - 1 for index in indices_to_alter]
206 |
207 | # Extract the maximum values
208 | max_indices_list_fg = []
209 | max_indices_list_bg = []
210 | dist_x = []
211 | dist_y = []
212 |
213 | cnt = 0
214 | for i in indices_to_alter:
215 | image = attention_for_text[:, :, i]
216 |
217 | box = [max(round(b / (512 / image.shape[0])), 0) for b in bbox[cnt]]
218 | x1, y1, x2, y2 = box
219 | cnt += 1
220 |
221 | # coordinates to masks
222 | obj_mask = torch.zeros_like(image)
223 | ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device)
224 | obj_mask[y1:y2, x1:x2] = ones_mask
225 | bg_mask = 1 - obj_mask
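            # BoxDiff constraints computed on the cross-attention map `image`:
            #   Inner-Box: the top-k responses inside the box are pushed towards 1,
            #   Outer-Box: the top-k responses outside the box are pushed towards 0,
            #   Corner: the x/y max-projections of the attention are matched to the
            #   box edges within a band of +/- config.L cells (see _compute_loss).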
226 |
227 | if smooth_attentions:
228 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
229 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
230 | image = smoothing(input).squeeze(0).squeeze(0)
231 |
232 | # Inner-Box constraint
233 | k = (obj_mask.sum() * config.P).long()
234 | max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean())
235 |
236 | # Outer-Box constraint
237 | k = (bg_mask.sum() * config.P).long()
238 | max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean())
239 |
240 | # Corner Constraint
241 | gt_proj_x = torch.max(obj_mask, dim=0)[0]
242 | gt_proj_y = torch.max(obj_mask, dim=1)[0]
243 | corner_mask_x = torch.zeros_like(gt_proj_x)
244 | corner_mask_y = torch.zeros_like(gt_proj_y)
245 |
246 | # create gt according to the number config.L
247 | N = gt_proj_x.shape[0]
248 | corner_mask_x[max(box[0] - config.L, 0): min(box[0] + config.L + 1, N)] = 1.
249 | corner_mask_x[max(box[2] - config.L, 0): min(box[2] + config.L + 1, N)] = 1.
250 | corner_mask_y[max(box[1] - config.L, 0): min(box[1] + config.L + 1, N)] = 1.
251 | corner_mask_y[max(box[3] - config.L, 0): min(box[3] + config.L + 1, N)] = 1.
252 | dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction='none') * corner_mask_x).mean())
253 | dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction='none') * corner_mask_y).mean())
254 |
255 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y
256 |
257 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore,
258 | indices_to_alter: List[int],
259 | attention_res: int = 16,
260 | smooth_attentions: bool = False,
261 | sigma: float = 0.5,
262 | kernel_size: int = 3,
263 | normalize_eot: bool = False,
264 | bbox: List[int] = None,
265 | config=None,
266 | ):
267 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. """
268 | attention_maps = aggregate_attention(
269 | attention_store=attention_store,
270 | res=attention_res,
271 | from_where=("up", "down", "mid"),
272 | is_cross=True,
273 | select=0)
274 |
275 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index(
276 | attention_maps=attention_maps,
277 | indices_to_alter=indices_to_alter,
278 | smooth_attentions=smooth_attentions,
279 | sigma=sigma,
280 | kernel_size=kernel_size,
281 | normalize_eot=normalize_eot,
282 | bbox=bbox,
283 | config=config,
284 | )
285 | return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y
286 |
287 | @staticmethod
288 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor],
289 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
290 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """
291 | losses_fg = [max(0, 1. - curr_max) for curr_max in max_attention_per_index_fg]
292 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg]
293 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)
294 | if return_losses:
295 | return max(losses_fg), losses_fg
296 | else:
297 | return max(losses_fg), loss
298 |
299 | @staticmethod
300 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
301 | """ Update the latent according to the computed loss. """
302 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
303 | latents = latents - step_size * grad_cond
304 | return latents
305 |
306 | def _perform_iterative_refinement_step(self,
307 | latents: torch.Tensor,
308 | indices_to_alter: List[int],
309 | loss_fg: torch.Tensor,
310 | threshold: float,
311 | text_embeddings: torch.Tensor,
312 | text_input,
313 | attention_store: AttentionStore,
314 | step_size: float,
315 | t: int,
316 | attention_res: int = 16,
317 | smooth_attentions: bool = True,
318 | sigma: float = 0.5,
319 | kernel_size: int = 3,
320 | max_refinement_steps: int = 20,
321 | normalize_eot: bool = False,
322 | bbox: List[int] = None,
323 | config=None,
324 | ):
325 | """
326 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
327 | code according to our loss objective until the given threshold is reached for all tokens.
328 | """
329 | iteration = 0
330 | target_loss = max(0, 1. - threshold)
331 | while loss_fg > target_loss:
332 | iteration += 1
333 |
334 | latents = latents.clone().detach().requires_grad_(True)
335 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
336 | self.unet.zero_grad()
337 |
338 | # Get max activation value for each subject token
339 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
340 | attention_store=attention_store,
341 | indices_to_alter=indices_to_alter,
342 | attention_res=attention_res,
343 | smooth_attentions=smooth_attentions,
344 | sigma=sigma,
345 | kernel_size=kernel_size,
346 | normalize_eot=normalize_eot,
347 | bbox=bbox,
348 | config=config,
349 | )
350 |
351 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True)
352 |
353 | if loss_fg != 0:
354 | latents = self._update_latent(latents, loss_fg, step_size)
355 |
356 | with torch.no_grad():
357 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample
358 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
359 |
360 | try:
361 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses_fg])
362 | except Exception as e:
363 | print(e) # catch edge case :)
364 |
365 | low_token = np.argmax(losses_fg)
366 |
367 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]])
368 | # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}')
369 |
370 | if iteration >= max_refinement_steps:
371 | # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! '
372 | # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}')
373 | break
374 |
375 | # Run one more time but don't compute gradients and update the latents.
376 | # We just need to compute the new loss - the grad update will occur below
377 | latents = latents.clone().detach().requires_grad_(True)
378 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
379 | self.unet.zero_grad()
380 |
381 | # Get max activation value for each subject token
382 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
383 | attention_store=attention_store,
384 | indices_to_alter=indices_to_alter,
385 | attention_res=attention_res,
386 | smooth_attentions=smooth_attentions,
387 | sigma=sigma,
388 | kernel_size=kernel_size,
389 | normalize_eot=normalize_eot,
390 | bbox=bbox,
391 | config=config,
392 | )
393 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True)
394 | # print(f"\t Finished with loss of: {loss_fg}")
395 | return loss_fg, latents, max_attention_per_index_fg
396 |
397 | @torch.no_grad()
398 | def __call__(
399 | self,
400 | prompt: Union[str, List[str]],
401 | attention_store: AttentionStore,
402 | indices_to_alter: List[int],
403 | attention_res: int = 16,
404 | height: Optional[int] = None,
405 | width: Optional[int] = None,
406 | num_inference_steps: int = 50,
407 | guidance_scale: float = 7.5,
408 | gligen_scheduled_sampling_beta: float = 0.3,
409 | gligen_phrases: List[str] = None,
410 | gligen_boxes: List[List[float]] = None,
411 | gligen_inpaint_image: Optional[PIL.Image.Image] = None,
412 | negative_prompt: Optional[Union[str, List[str]]] = None,
413 | num_images_per_prompt: Optional[int] = 1,
414 | eta: float = 0.0,
415 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
416 | latents: Optional[torch.FloatTensor] = None,
417 | prompt_embeds: Optional[torch.FloatTensor] = None,
418 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
419 | output_type: Optional[str] = "pil",
420 | return_dict: bool = True,
421 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
422 | callback_steps: Optional[int] = 1,
423 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
424 | max_iter_to_alter: Optional[int] = 25,
425 | run_standard_sd: bool = False,
426 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8},
427 | scale_factor: int = 20,
428 | scale_range: Tuple[float, float] = (1., 0.5),
429 | smooth_attentions: bool = True,
430 | sigma: float = 0.5,
431 | kernel_size: int = 3,
432 | sd_2_1: bool = False,
433 | bbox: List[int] = None,
434 | config = None,
435 | ):
436 | r"""
437 | Function invoked when calling the pipeline for generation.
438 | Args:
439 | prompt (`str` or `List[str]`, *optional*):
440 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass
441 |                 `prompt_embeds` instead.
442 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
443 | The height in pixels of the generated image.
444 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
445 | The width in pixels of the generated image.
446 | num_inference_steps (`int`, *optional*, defaults to 50):
447 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
448 | expense of slower inference.
449 | guidance_scale (`float`, *optional*, defaults to 7.5):
450 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
451 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
452 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
453 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
454 | usually at the expense of lower image quality.
455 | negative_prompt (`str` or `List[str]`, *optional*):
456 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
457 |                 `negative_prompt_embeds` instead.
458 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
459 | num_images_per_prompt (`int`, *optional*, defaults to 1):
460 | The number of images to generate per prompt.
461 | eta (`float`, *optional*, defaults to 0.0):
462 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
463 | [`schedulers.DDIMScheduler`], will be ignored for others.
464 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
465 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
466 | to make generation deterministic.
467 | latents (`torch.FloatTensor`, *optional*):
468 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
469 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
470 |                 tensor will be generated by sampling using the supplied random `generator`.
471 | prompt_embeds (`torch.FloatTensor`, *optional*):
472 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
473 | provided, text embeddings will be generated from `prompt` input argument.
474 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
475 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
476 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
477 | argument.
478 | output_type (`str`, *optional*, defaults to `"pil"`):
479 |                 The output format of the generated image. Choose between
480 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
481 | return_dict (`bool`, *optional*, defaults to `True`):
482 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
483 | plain tuple.
484 | callback (`Callable`, *optional*):
485 | A function that will be called every `callback_steps` steps during inference. The function will be
486 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
487 | callback_steps (`int`, *optional*, defaults to 1):
488 | The frequency at which the `callback` function will be called. If not specified, the callback will be
489 | called at every step.
490 | cross_attention_kwargs (`dict`, *optional*):
491 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
492 | `self.processor` in
493 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
494 | Examples:
495 | Returns:
496 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
497 |                 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
498 | When returning a tuple, the first element is a list with the generated images, and the second element is a
499 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
500 | (nsfw) content, according to the `safety_checker`.
501 | :type attention_store: object
502 | """
503 | # 0. Default height and width to unet
504 | height = height or self.unet.config.sample_size * self.vae_scale_factor
505 | width = width or self.unet.config.sample_size * self.vae_scale_factor
506 |
507 | # 1. Check inputs. Raise error if not correct
508 | self.check_inputs(
509 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
510 | )
511 |
512 | # 2. Define call parameters
513 | self.prompt = prompt
514 | if prompt is not None and isinstance(prompt, str):
515 | batch_size = 1
516 | elif prompt is not None and isinstance(prompt, list):
517 | batch_size = len(prompt)
518 | else:
519 | batch_size = prompt_embeds.shape[0]
520 |
521 | device = self._execution_device
522 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
523 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
524 | # corresponds to doing no classifier free guidance.
525 | do_classifier_free_guidance = guidance_scale > 1.0
526 |
527 | # 3. Encode input prompt
528 | text_inputs, prompt_embeds = self._encode_prompt(
529 | prompt,
530 | device,
531 | num_images_per_prompt,
532 | do_classifier_free_guidance,
533 | negative_prompt,
534 | prompt_embeds=prompt_embeds,
535 | negative_prompt_embeds=negative_prompt_embeds,
536 | )
537 |
538 | # 4. Prepare timesteps
539 | self.scheduler.set_timesteps(num_inference_steps, device=device)
540 | timesteps = self.scheduler.timesteps
541 |
542 | # 5. Prepare latent variables
543 | num_channels_latents = self.unet.in_channels
544 | latents = self.prepare_latents(
545 | batch_size * num_images_per_prompt,
546 | num_channels_latents,
547 | height,
548 | width,
549 | prompt_embeds.dtype,
550 | device,
551 | generator,
552 | latents,
553 | )
554 |
555 | def draw_inpaint_mask_from_boxes(boxes, size):
556 | inpaint_mask = torch.ones(size[0], size[1])
557 | for box in boxes:
558 | x0, x1 = box[0] * size[0], box[2] * size[0]
559 | y0, y1 = box[1] * size[1], box[3] * size[1]
560 | inpaint_mask[int(y0):int(y1), int(x0):int(x1)] = 0
561 | return inpaint_mask
562 |
563 | # 5.1 Prepare GLIGEN variables
564 | if gligen_phrases is not None:
565 | assert len(gligen_phrases) == len(gligen_boxes)
566 | assert batch_size == 1
567 | max_objs = 30
568 | _boxes = gligen_boxes
569 | tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(
570 | self.text_encoder.device)
571 | _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output
572 | n_objs = min(len(_boxes), max_objs)
573 | device = self.text_encoder.device
574 | dtype = self.text_encoder.dtype
575 | boxes = torch.zeros(max_objs, 4, device=device, dtype=dtype)
576 | boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
577 | text_embeddings = torch.zeros(max_objs, 768, device=device, dtype=dtype)
578 | text_embeddings[:n_objs] = _text_embeddings[:n_objs]
579 | masks = torch.zeros(max_objs, device=device, dtype=dtype)
580 | masks[:n_objs] = 1
581 |
582 | repeat_batch = batch_size * num_images_per_prompt
583 | if do_classifier_free_guidance:
584 | repeat_batch = repeat_batch * 2
585 | boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
586 | text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
587 | masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
588 | if do_classifier_free_guidance:
589 | masks[:repeat_batch // 2] = 0
590 | if cross_attention_kwargs is None:
591 | cross_attention_kwargs = {}
592 | cross_attention_kwargs['gligen'] = {
593 | 'boxes': boxes,
594 | 'positive_embeddings': text_embeddings,
595 | 'masks': masks
596 | }
597 |
598 | # Prepare latent variables for GLIGEN inpainting
599 | if gligen_inpaint_image is not None:
600 | if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size):
601 | def crop(im, new_width, new_height):
602 | width, height = im.size
603 | left = (width - new_width) / 2
604 | top = (height - new_height) / 2
605 | right = (width + new_width) / 2
606 | bottom = (height + new_height) / 2
607 | return im.crop((left, top, right, bottom))
608 |
609 | def target_size_center_crop(im, new_hw):
610 | width, height = im.size
611 | if width != height:
612 | im = crop(im, min(height, width), min(height, width))
613 | return im.resize((new_hw, new_hw), PIL.Image.LANCZOS)
614 |
615 | gligen_inpaint_image = target_size_center_crop(gligen_inpaint_image, self.vae.sample_size)
616 |
617 | gligen_inpaint_image = torch.from_numpy(np.asarray(gligen_inpaint_image))
618 | gligen_inpaint_image = gligen_inpaint_image.unsqueeze(0).permute(0, 3, 1, 2)
619 | gligen_inpaint_image = gligen_inpaint_image.to(dtype=torch.float32) / 127.5 - 1.0
620 | gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device)
621 | gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample()
622 | gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent
623 | gligen_inpaint_mask = draw_inpaint_mask_from_boxes(_boxes[:n_objs], gligen_inpaint_latent.shape[2:])
624 | gligen_inpaint_mask = gligen_inpaint_mask.to(dtype=gligen_inpaint_latent.dtype,
625 | device=gligen_inpaint_latent.device)
626 | gligen_inpaint_mask = gligen_inpaint_mask[None, None]
627 | gligen_inpaint_mask_addition = torch.cat(
628 | (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1)
629 | gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone()
630 |
631 | num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
632 | self.enable_fuser(True)
633 |
634 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
635 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
636 |
637 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))
638 |
639 | if max_iter_to_alter is None:
640 | max_iter_to_alter = len(self.scheduler.timesteps) + 1
641 |
642 | # 7. Denoising loop
643 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
644 | with self.progress_bar(total=num_inference_steps) as progress_bar:
645 | for i, t in enumerate(timesteps):
646 |
647 | with torch.enable_grad():
648 |
649 | if i == num_grounding_steps:
650 | self.enable_fuser(False)
651 |
652 | if latents.shape[1] != 4:
653 | latents = torch.randn_like(latents[:, :4])
654 |
655 | if gligen_inpaint_image is not None:
656 | gligen_inpaint_latent_with_noise = self.scheduler.add_noise(
657 | gligen_inpaint_latent,
658 | torch.randn_like(gligen_inpaint_latent),
659 | t
660 | ).expand(latents.shape[0], -1, -1, -1).clone()
661 | latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * (
662 | 1 - gligen_inpaint_mask)
663 |
664 | latents = latents.clone().detach().requires_grad_(True)
665 |
666 | # expand the latents if we are doing classifier free guidance
667 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
668 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
669 |
670 | if gligen_inpaint_image is not None:
671 | latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1)
672 |
673 | # Forward pass of denoising with text conditioning
674 | noise_pred_text = self.unet(latent_model_input, t,
675 | encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs).sample
676 | self.unet.zero_grad()
677 |
678 | # Get max activation value for each subject token
679 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
680 | attention_store=attention_store,
681 | indices_to_alter=indices_to_alter,
682 | attention_res=attention_res,
683 | smooth_attentions=smooth_attentions,
684 | sigma=sigma,
685 | kernel_size=kernel_size,
686 | normalize_eot=sd_2_1,
687 | bbox=bbox,
688 | config=config,
689 | )
690 |
691 | if not run_standard_sd:
692 |
693 | loss_fg, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y)
694 |
695 | # Refinement from attend-and-excite (not necessary)
696 | if i in thresholds.keys() and loss_fg > 1. - thresholds[i] and config.refine:
697 | del noise_pred_text
698 | torch.cuda.empty_cache()
699 | loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step(
700 | latents=latents,
701 | indices_to_alter=indices_to_alter,
702 | loss_fg=loss_fg,
703 | threshold=thresholds[i],
704 | text_embeddings=prompt_embeds,
705 | text_input=text_inputs,
706 | attention_store=attention_store,
707 | step_size=scale_factor * np.sqrt(scale_range[i]),
708 | t=t,
709 | attention_res=attention_res,
710 | smooth_attentions=smooth_attentions,
711 | sigma=sigma,
712 | kernel_size=kernel_size,
713 | normalize_eot=sd_2_1,
714 | bbox=bbox,
715 | config=config,
716 | )
717 |
718 | # Perform gradient update
719 | if i < max_iter_to_alter:
720 | _, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y)
721 | if loss != 0:
722 | latents = self._update_latent(latents=latents, loss=loss,
723 | step_size=scale_factor * np.sqrt(scale_range[i]))
724 |
725 | # print(f'Iteration {i} | Loss: {loss:0.4f}')
726 |
727 | # predict the noise residual
728 | noise_pred = self.unet(
729 | latent_model_input,
730 | t,
731 | encoder_hidden_states=prompt_embeds,
732 | cross_attention_kwargs=cross_attention_kwargs,
733 | ).sample
734 |
735 | # perform guidance
736 | if do_classifier_free_guidance:
737 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
738 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
739 |
740 | # compute the previous noisy sample x_t -> x_t-1
741 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
742 |
743 | # call the callback, if provided
744 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
745 | progress_bar.update()
746 | if callback is not None and i % callback_steps == 0:
747 | callback(i, t, latents)
748 |
749 | # 8. Post-processing
750 | image = self.decode_latents(latents)
751 |
752 | # 9. Run safety checker
753 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
754 |
755 | # 10. Convert to PIL
756 | if output_type == "pil":
757 | image = self.numpy_to_pil(image)
758 |
759 | if not return_dict:
760 | return (image, has_nsfw_concept)
761 |
762 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
763 |
--------------------------------------------------------------------------------
/pipeline/sd_pipeline_boxdiff.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
9 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10 |
11 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
12 |
13 | from utils.gaussian_smoothing import GaussianSmoothing
14 | from utils.ptp_utils import AttentionStore, aggregate_attention
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 | class BoxDiffPipeline(StableDiffusionPipeline):
19 | r"""
20 | Pipeline for text-to-image generation using Stable Diffusion.
21 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
22 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
23 | Args:
24 | vae ([`AutoencoderKL`]):
25 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
26 | text_encoder ([`CLIPTextModel`]):
27 | Frozen text-encoder. Stable Diffusion uses the text portion of
28 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
29 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
30 | tokenizer (`CLIPTokenizer`):
31 | Tokenizer of class
32 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
33 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
34 | scheduler ([`SchedulerMixin`]):
35 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
36 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
37 | safety_checker ([`StableDiffusionSafetyChecker`]):
38 | Classification module that estimates whether generated images could be considered offensive or harmful.
39 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
40 | feature_extractor ([`CLIPFeatureExtractor`]):
41 | Model that extracts features from generated images to be used as inputs for the `safety_checker`.
42 | """
43 | _optional_components = ["safety_checker", "feature_extractor"]
44 |
45 | def _encode_prompt(
46 | self,
47 | prompt,
48 | device,
49 | num_images_per_prompt,
50 | do_classifier_free_guidance,
51 | negative_prompt=None,
52 | prompt_embeds: Optional[torch.FloatTensor] = None,
53 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
54 | ):
55 | r"""
56 | Encodes the prompt into text encoder hidden states.
57 | Args:
58 | prompt (`str` or `List[str]`, *optional*):
59 | prompt to be encoded
60 | device: (`torch.device`):
61 | torch device
62 | num_images_per_prompt (`int`):
63 | number of images that should be generated per prompt
64 | do_classifier_free_guidance (`bool`):
65 | whether to use classifier free guidance or not
 66 |             negative_prompt (`str` or `List[str]`, *optional*):
 67 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
 68 |                 `negative_prompt_embeds` instead.
69 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
70 | prompt_embeds (`torch.FloatTensor`, *optional*):
71 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
72 | provided, text embeddings will be generated from `prompt` input argument.
73 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
74 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
75 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
76 | argument.
77 | """
78 | if prompt is not None and isinstance(prompt, str):
79 | batch_size = 1
80 | elif prompt is not None and isinstance(prompt, list):
81 | batch_size = len(prompt)
82 | else:
83 | batch_size = prompt_embeds.shape[0]
84 |
85 | if prompt_embeds is None:
86 | text_inputs = self.tokenizer(
87 | prompt,
88 | padding="max_length",
89 | max_length=self.tokenizer.model_max_length,
90 | truncation=True,
91 | return_tensors="pt",
92 | )
93 | text_input_ids = text_inputs.input_ids
94 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
95 |
96 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
97 | text_input_ids, untruncated_ids
98 | ):
99 | removed_text = self.tokenizer.batch_decode(
100 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
101 | )
102 | logger.warning(
103 | "The following part of your input was truncated because CLIP can only handle sequences up to"
104 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
105 | )
106 |
107 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
108 | attention_mask = text_inputs.attention_mask.to(device)
109 | else:
110 | attention_mask = None
111 |
112 | prompt_embeds = self.text_encoder(
113 | text_input_ids.to(device),
114 | attention_mask=attention_mask,
115 | )
116 | prompt_embeds = prompt_embeds[0]
117 |
118 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
119 |
120 | bs_embed, seq_len, _ = prompt_embeds.shape
121 | # duplicate text embeddings for each generation per prompt, using mps friendly method
122 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
123 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
124 |
125 | # get unconditional embeddings for classifier free guidance
126 | if do_classifier_free_guidance and negative_prompt_embeds is None:
127 | uncond_tokens: List[str]
128 | if negative_prompt is None:
129 | uncond_tokens = [""] * batch_size
130 | elif type(prompt) is not type(negative_prompt):
131 | raise TypeError(
132 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
133 | f" {type(prompt)}."
134 | )
135 | elif isinstance(negative_prompt, str):
136 | uncond_tokens = [negative_prompt]
137 | elif batch_size != len(negative_prompt):
138 | raise ValueError(
139 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
140 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
141 | " the batch size of `prompt`."
142 | )
143 | else:
144 | uncond_tokens = negative_prompt
145 |
146 | max_length = prompt_embeds.shape[1]
147 | uncond_input = self.tokenizer(
148 | uncond_tokens,
149 | padding="max_length",
150 | max_length=max_length,
151 | truncation=True,
152 | return_tensors="pt",
153 | )
154 |
155 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
156 | attention_mask = uncond_input.attention_mask.to(device)
157 | else:
158 | attention_mask = None
159 |
160 | negative_prompt_embeds = self.text_encoder(
161 | uncond_input.input_ids.to(device),
162 | attention_mask=attention_mask,
163 | )
164 | negative_prompt_embeds = negative_prompt_embeds[0]
165 |
166 | if do_classifier_free_guidance:
167 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
168 | seq_len = negative_prompt_embeds.shape[1]
169 |
170 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
171 |
172 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
173 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
174 |
175 | # For classifier free guidance, we need to do two forward passes.
176 | # Here we concatenate the unconditional and text embeddings into a single batch
177 | # to avoid doing two forward passes
178 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
179 |
180 | return text_inputs, prompt_embeds
181 |
182 | def _compute_max_attention_per_index(self,
183 | attention_maps: torch.Tensor,
184 | indices_to_alter: List[int],
185 | smooth_attentions: bool = False,
186 | sigma: float = 0.5,
187 | kernel_size: int = 3,
188 | normalize_eot: bool = False,
189 | bbox: List[int] = None,
190 | config=None,
191 | ) -> List[torch.Tensor]:
192 | """ Computes the maximum attention value for each of the tokens we wish to alter. """
193 | last_idx = -1
194 | if normalize_eot:
195 | prompt = self.prompt
196 | if isinstance(self.prompt, list):
197 | prompt = self.prompt[0]
198 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1
199 | attention_for_text = attention_maps[:, :, 1:last_idx]
200 | attention_for_text *= 100
201 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)
202 |
203 | # Shift indices since we removed the first token
204 | indices_to_alter = [index - 1 for index in indices_to_alter]
205 |
206 | # Extract the maximum values
207 | max_indices_list_fg = []
208 | max_indices_list_bg = []
209 | dist_x = []
210 | dist_y = []
211 |
212 | cnt = 0
213 | for i in indices_to_alter:
214 | image = attention_for_text[:, :, i]
215 |
216 | box = [max(round(b / (512 / image.shape[0])), 0) for b in bbox[cnt]]
217 | x1, y1, x2, y2 = box
218 | cnt += 1
219 |
220 | # coordinates to masks
221 | obj_mask = torch.zeros_like(image)
222 | ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device)
223 | obj_mask[y1:y2, x1:x2] = ones_mask
224 | bg_mask = 1 - obj_mask
225 |
226 | if smooth_attentions:
227 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
228 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
229 | image = smoothing(input).squeeze(0).squeeze(0)
230 |
231 | # Inner-Box constraint
232 | k = (obj_mask.sum() * config.P).long()
233 | max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean())
234 |
235 | # Outer-Box constraint
236 | k = (bg_mask.sum() * config.P).long()
237 | max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean())
238 |
239 | # Corner Constraint
240 | gt_proj_x = torch.max(obj_mask, dim=0)[0]
241 | gt_proj_y = torch.max(obj_mask, dim=1)[0]
242 | corner_mask_x = torch.zeros_like(gt_proj_x)
243 | corner_mask_y = torch.zeros_like(gt_proj_y)
244 |
245 | # create gt according to the number config.L
246 | N = gt_proj_x.shape[0]
247 | corner_mask_x[max(box[0] - config.L, 0): min(box[0] + config.L + 1, N)] = 1.
248 | corner_mask_x[max(box[2] - config.L, 0): min(box[2] + config.L + 1, N)] = 1.
249 | corner_mask_y[max(box[1] - config.L, 0): min(box[1] + config.L + 1, N)] = 1.
250 | corner_mask_y[max(box[3] - config.L, 0): min(box[3] + config.L + 1, N)] = 1.
251 | dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction='none') * corner_mask_x).mean())
252 | dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction='none') * corner_mask_y).mean())
253 |
254 | return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y
255 |
256 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore,
257 | indices_to_alter: List[int],
258 | attention_res: int = 16,
259 | smooth_attentions: bool = False,
260 | sigma: float = 0.5,
261 | kernel_size: int = 3,
262 | normalize_eot: bool = False,
263 | bbox: List[int] = None,
264 | config=None,
265 | ):
266 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. """
267 | attention_maps = aggregate_attention(
268 | attention_store=attention_store,
269 | res=attention_res,
270 | from_where=("up", "down", "mid"),
271 | is_cross=True,
272 | select=0)
273 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index(
274 | attention_maps=attention_maps,
275 | indices_to_alter=indices_to_alter,
276 | smooth_attentions=smooth_attentions,
277 | sigma=sigma,
278 | kernel_size=kernel_size,
279 | normalize_eot=normalize_eot,
280 | bbox=bbox,
281 | config=config,
282 | )
283 | return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y
284 |
285 | @staticmethod
286 | def _compute_loss(max_attention_per_index_fg: List[torch.Tensor], max_attention_per_index_bg: List[torch.Tensor],
287 | dist_x: List[torch.Tensor], dist_y: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
288 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """
289 | losses_fg = [max(0, 1. - curr_max) for curr_max in max_attention_per_index_fg]
290 | losses_bg = [max(0, curr_max) for curr_max in max_attention_per_index_bg]
291 | loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)
292 | if return_losses:
293 | return max(losses_fg), losses_fg
294 | else:
295 | return max(losses_fg), loss
296 |
297 | @staticmethod
298 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
299 | """ Update the latent according to the computed loss. """
300 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
301 | latents = latents - step_size * grad_cond
302 | return latents
303 |
304 | def _perform_iterative_refinement_step(self,
305 | latents: torch.Tensor,
306 | indices_to_alter: List[int],
307 | loss_fg: torch.Tensor,
308 | threshold: float,
309 | text_embeddings: torch.Tensor,
310 | text_input,
311 | attention_store: AttentionStore,
312 | step_size: float,
313 | t: int,
314 | attention_res: int = 16,
315 | smooth_attentions: bool = True,
316 | sigma: float = 0.5,
317 | kernel_size: int = 3,
318 | max_refinement_steps: int = 20,
319 | normalize_eot: bool = False,
320 | bbox: List[int] = None,
321 | config=None,
322 | ):
323 | """
324 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
325 | code according to our loss objective until the given threshold is reached for all tokens.
326 | """
327 | iteration = 0
328 | target_loss = max(0, 1. - threshold)
329 | while loss_fg > target_loss:
330 | iteration += 1
331 |
332 | latents = latents.clone().detach().requires_grad_(True)
333 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
334 | self.unet.zero_grad()
335 |
336 | # Get max activation value for each subject token
337 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
338 | attention_store=attention_store,
339 | indices_to_alter=indices_to_alter,
340 | attention_res=attention_res,
341 | smooth_attentions=smooth_attentions,
342 | sigma=sigma,
343 | kernel_size=kernel_size,
344 | normalize_eot=normalize_eot,
345 | bbox=bbox,
346 | config=config,
347 | )
348 |
349 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True)
350 |
351 | if loss_fg != 0:
352 | latents = self._update_latent(latents, loss_fg, step_size)
353 |
354 | with torch.no_grad():
355 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample
356 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
357 |
358 | try:
359 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses_fg])
360 | except Exception as e:
361 | print(e) # catch edge case :)
362 |
363 | low_token = np.argmax(losses_fg)
364 |
365 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]])
366 | # print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}')
367 |
368 | if iteration >= max_refinement_steps:
369 | # print(f'\t Exceeded max number of iterations ({max_refinement_steps})! '
370 | # f'Finished with a max attention of {max_attention_per_index_fg[low_token]}')
371 | break
372 |
373 | # Run one more time but don't compute gradients and update the latents.
374 | # We just need to compute the new loss - the grad update will occur below
375 | latents = latents.clone().detach().requires_grad_(True)
376 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
377 | self.unet.zero_grad()
378 |
379 | # Get max activation value for each subject token
380 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
381 | attention_store=attention_store,
382 | indices_to_alter=indices_to_alter,
383 | attention_res=attention_res,
384 | smooth_attentions=smooth_attentions,
385 | sigma=sigma,
386 | kernel_size=kernel_size,
387 | normalize_eot=normalize_eot,
388 | bbox=bbox,
389 | config=config,
390 | )
391 | loss_fg, losses_fg = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True)
392 | # print(f"\t Finished with loss of: {loss_fg}")
393 | return loss_fg, latents, max_attention_per_index_fg
394 |
395 | @torch.no_grad()
396 | def __call__(
397 | self,
398 | prompt: Union[str, List[str]],
399 | attention_store: AttentionStore,
400 | indices_to_alter: List[int],
401 | attention_res: int = 16,
402 | height: Optional[int] = None,
403 | width: Optional[int] = None,
404 | num_inference_steps: int = 50,
405 | guidance_scale: float = 7.5,
406 | negative_prompt: Optional[Union[str, List[str]]] = None,
407 | num_images_per_prompt: Optional[int] = 1,
408 | eta: float = 0.0,
409 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
410 | latents: Optional[torch.FloatTensor] = None,
411 | prompt_embeds: Optional[torch.FloatTensor] = None,
412 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
413 | output_type: Optional[str] = "pil",
414 | return_dict: bool = True,
415 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
416 | callback_steps: Optional[int] = 1,
417 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
418 | max_iter_to_alter: Optional[int] = 25,
419 | run_standard_sd: bool = False,
420 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8},
421 | scale_factor: int = 20,
422 | scale_range: Tuple[float, float] = (1., 0.5),
423 | smooth_attentions: bool = True,
424 | sigma: float = 0.5,
425 | kernel_size: int = 3,
426 | sd_2_1: bool = False,
427 | bbox: List[int] = None,
428 | config = None,
429 | ):
430 | r"""
431 | Function invoked when calling the pipeline for generation.
432 | Args:
433 | prompt (`str` or `List[str]`, *optional*):
434 | The prompt or prompts to guide the image generation. If not defined, one has to pass
435 | `prompt_embeds` instead.
436 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
437 | The height in pixels of the generated image.
438 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
439 | The width in pixels of the generated image.
440 | num_inference_steps (`int`, *optional*, defaults to 50):
441 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
442 | expense of slower inference.
443 | guidance_scale (`float`, *optional*, defaults to 7.5):
444 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
445 | `guidance_scale` is defined as `w` of equation 2 of [Imagen
446 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
447 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
448 | text `prompt`, usually at the expense of lower image quality.
449 | negative_prompt (`str` or `List[str]`, *optional*):
450 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
451 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
452 | `guidance_scale` is less than `1`).
453 | num_images_per_prompt (`int`, *optional*, defaults to 1):
454 | The number of images to generate per prompt.
455 | eta (`float`, *optional*, defaults to 0.0):
456 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
457 | [`schedulers.DDIMScheduler`], will be ignored for others.
458 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
459 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
460 | to make generation deterministic.
461 | latents (`torch.FloatTensor`, *optional*):
462 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
463 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
464 | tensor will be generated by sampling using the supplied random `generator`.
465 | prompt_embeds (`torch.FloatTensor`, *optional*):
466 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
467 | provided, text embeddings will be generated from `prompt` input argument.
468 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
469 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
470 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
471 | argument.
472 | output_type (`str`, *optional*, defaults to `"pil"`):
473 | The output format of the generated image. Choose between
474 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
475 | return_dict (`bool`, *optional*, defaults to `True`):
476 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
477 | plain tuple.
478 | callback (`Callable`, *optional*):
479 | A function that will be called every `callback_steps` steps during inference. The function will be
480 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
481 | callback_steps (`int`, *optional*, defaults to 1):
482 | The frequency at which the `callback` function will be called. If not specified, the callback will be
483 | called at every step.
484 | cross_attention_kwargs (`dict`, *optional*):
485 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
486 | `self.processor` in
487 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
488 | Examples:
489 | Returns:
490 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
491 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
492 | When returning a tuple, the first element is a list with the generated images, and the second element is a
493 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
494 | (nsfw) content, according to the `safety_checker`.
495 | :type attention_store: object
496 | """
497 | # 0. Default height and width to unet
498 | height = height or self.unet.config.sample_size * self.vae_scale_factor
499 | width = width or self.unet.config.sample_size * self.vae_scale_factor
500 |
501 | # 1. Check inputs. Raise error if not correct
502 | self.check_inputs(
503 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
504 | )
505 |
506 | # 2. Define call parameters
507 | self.prompt = prompt
508 | if prompt is not None and isinstance(prompt, str):
509 | batch_size = 1
510 | elif prompt is not None and isinstance(prompt, list):
511 | batch_size = len(prompt)
512 | else:
513 | batch_size = prompt_embeds.shape[0]
514 |
515 | device = self._execution_device
516 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
517 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
518 | # corresponds to doing no classifier free guidance.
519 | do_classifier_free_guidance = guidance_scale > 1.0
520 |
521 | # 3. Encode input prompt
522 | text_inputs, prompt_embeds = self._encode_prompt(
523 | prompt,
524 | device,
525 | num_images_per_prompt,
526 | do_classifier_free_guidance,
527 | negative_prompt,
528 | prompt_embeds=prompt_embeds,
529 | negative_prompt_embeds=negative_prompt_embeds,
530 | )
531 |
532 | # 4. Prepare timesteps
533 | self.scheduler.set_timesteps(num_inference_steps, device=device)
534 | timesteps = self.scheduler.timesteps
535 |
536 | # 5. Prepare latent variables
537 | num_channels_latents = self.unet.in_channels
538 | latents = self.prepare_latents(
539 | batch_size * num_images_per_prompt,
540 | num_channels_latents,
541 | height,
542 | width,
543 | prompt_embeds.dtype,
544 | device,
545 | generator,
546 | latents,
547 | )
548 |
549 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
550 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
551 |
552 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))
553 |
554 | if max_iter_to_alter is None:
555 | max_iter_to_alter = len(self.scheduler.timesteps) + 1
556 |
557 | # 7. Denoising loop
558 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
559 | with self.progress_bar(total=num_inference_steps) as progress_bar:
560 | for i, t in enumerate(timesteps):
561 |
562 | with torch.enable_grad():
563 |
564 | latents = latents.clone().detach().requires_grad_(True)
565 |
566 | # Forward pass of denoising with text conditioning
567 | noise_pred_text = self.unet(latents, t,
568 | encoder_hidden_states=prompt_embeds[1].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs).sample
569 | self.unet.zero_grad()
570 |
571 | # Get max activation value for each subject token
572 | max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._aggregate_and_get_max_attention_per_token(
573 | attention_store=attention_store,
574 | indices_to_alter=indices_to_alter,
575 | attention_res=attention_res,
576 | smooth_attentions=smooth_attentions,
577 | sigma=sigma,
578 | kernel_size=kernel_size,
579 | normalize_eot=sd_2_1,
580 | bbox=bbox,
581 | config=config,
582 | )
583 |
584 | if not run_standard_sd:
585 |
586 | loss_fg, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y)
587 |
588 | # Refinement from attend-and-excite (not necessary)
589 | if i in thresholds.keys() and loss_fg > 1. - thresholds[i] and config.refine:
590 | del noise_pred_text
591 | torch.cuda.empty_cache()
592 | loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step(
593 | latents=latents,
594 | indices_to_alter=indices_to_alter,
595 | loss_fg=loss_fg,
596 | threshold=thresholds[i],
597 | text_embeddings=prompt_embeds,
598 | text_input=text_inputs,
599 | attention_store=attention_store,
600 | step_size=scale_factor * np.sqrt(scale_range[i]),
601 | t=t,
602 | attention_res=attention_res,
603 | smooth_attentions=smooth_attentions,
604 | sigma=sigma,
605 | kernel_size=kernel_size,
606 | normalize_eot=sd_2_1,
607 | bbox=bbox,
608 | config=config,
609 | )
610 |
611 | # Perform gradient update
612 | if i < max_iter_to_alter:
613 | _, loss = self._compute_loss(max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y)
614 | if loss != 0:
615 | latents = self._update_latent(latents=latents, loss=loss,
616 | step_size=scale_factor * np.sqrt(scale_range[i]))
617 |
618 | # print(f'Iteration {i} | Loss: {loss:0.4f}')
619 |
620 | # expand the latents if we are doing classifier free guidance
621 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
622 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
623 |
624 | # predict the noise residual
625 | noise_pred = self.unet(
626 | latent_model_input,
627 | t,
628 | encoder_hidden_states=prompt_embeds,
629 | cross_attention_kwargs=cross_attention_kwargs,
630 | ).sample
631 |
632 | # perform guidance
633 | if do_classifier_free_guidance:
634 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
635 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
636 |
637 | # compute the previous noisy sample x_t -> x_t-1
638 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
639 |
640 | # call the callback, if provided
641 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
642 | progress_bar.update()
643 | if callback is not None and i % callback_steps == 0:
644 | callback(i, t, latents)
645 |
646 | # 8. Post-processing
647 | image = self.decode_latents(latents)
648 |
649 | # 9. Run safety checker
650 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
651 |
652 | # 10. Convert to PIL
653 | if output_type == "pil":
654 | image = self.numpy_to_pil(image)
655 |
656 | if not return_dict:
657 | return (image, has_nsfw_concept)
658 |
659 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
660 |
--------------------------------------------------------------------------------
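The `_compute_max_attention_per_index` and `_compute_loss` methods above turn each user box into three attention constraints: an Inner-Box term (the top responses inside the box should be high), an Outer-Box term (the top responses outside the box should be low), and a Corner term (the 1-D projections of the attention should match the box edges within a small band). The snippet below is a minimal, self-contained sketch of those constraints on a random toy attention map; the map size, box coordinates and the `P`/`L` values are illustrative stand-ins for `config.P` and `config.L`, not values taken from the repository's config.

```
import torch
import torch.nn.functional as F

# Toy 16x16 cross-attention map for a single token (random, illustration only).
torch.manual_seed(0)
attn = torch.rand(16, 16)

# A user box in pixel coordinates of a 512x512 image, rescaled to the 16x16 attention grid.
bbox_px = [96, 128, 352, 416]                                    # x1, y1, x2, y2
x1, y1, x2, y2 = [max(round(b / (512 / 16)), 0) for b in bbox_px]

# Binary masks for inside / outside the box.
obj_mask = torch.zeros_like(attn)
obj_mask[y1:y2, x1:x2] = 1.0
bg_mask = 1.0 - obj_mask

P = 0.2   # fraction of top responses used in the constraints (plays the role of config.P)
L = 1     # half-width of the corner bands, in grid cells (plays the role of config.L)

# Inner-Box constraint: the strongest responses inside the box should be close to 1.
k_fg = int(obj_mask.sum().item() * P)
fg_term = (attn * obj_mask).reshape(-1).topk(k_fg)[0].mean()

# Outer-Box constraint: the strongest responses outside the box should be close to 0.
k_bg = int(bg_mask.sum().item() * P)
bg_term = (attn * bg_mask).reshape(-1).topk(k_bg)[0].mean()

# Corner constraint: 1-D projections of the attention should match the projections
# of the box mask in small bands around the box edges.
gt_proj_x = obj_mask.max(dim=0)[0]
gt_proj_y = obj_mask.max(dim=1)[0]
corner_x = torch.zeros_like(gt_proj_x)
corner_y = torch.zeros_like(gt_proj_y)
N = gt_proj_x.shape[0]
for e in (x1, x2):
    corner_x[max(e - L, 0): min(e + L + 1, N)] = 1.0
for e in (y1, y2):
    corner_y[max(e - L, 0): min(e + L + 1, N)] = 1.0
dist_x = (F.l1_loss(attn.max(dim=0)[0], gt_proj_x, reduction='none') * corner_x).mean()
dist_y = (F.l1_loss(attn.max(dim=1)[0], gt_proj_y, reduction='none') * corner_y).mean()

# Combined loss, mirroring _compute_loss: push fg_term towards 1, bg_term towards 0,
# and keep the attention projections tight around the box corners.
loss = max(0.0, 1.0 - fg_term) + max(0.0, bg_term) + dist_x + dist_y
print(float(loss))
```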
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | opencv-python
3 | ipywidgets
4 | matplotlib
5 | pyrallis
6 | torch==1.12.0
7 | torchvision==0.13.0
8 | transformers==4.26.0
9 | jupyter
10 | accelerate
11 | IPython
12 |
--------------------------------------------------------------------------------
/run_gligen_boxdiff.py:
--------------------------------------------------------------------------------
1 |
2 | import pprint
3 | from typing import List
4 |
5 | import pyrallis
6 | import torch
7 | from PIL import Image
8 | from config import RunConfig
9 | from pipeline.gligen_pipeline_boxdiff import BoxDiffPipeline
10 | from utils import ptp_utils, vis_utils
11 | from utils.ptp_utils import AttentionStore
12 |
13 | import numpy as np
14 | from utils.drawer import draw_rectangle, DashedImageDraw
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore", category=UserWarning)
18 |
19 |
20 | def load_model():
21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22 | stable_diffusion_version = "gligen/diffusers-generation-text-box"
23 | # If you cannot access Hugging Face from your server, you can point this to a locally prepared copy.
24 | # stable_diffusion_version = "../../packages/diffusers/gligen_ckpts/diffusers-generation-text-box"
25 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version, revision="fp16", torch_dtype=torch.float16).to(device)
26 |
27 | return stable
28 |
29 |
30 | def get_indices_to_alter(stable, prompt: str) -> List[int]:
31 | token_idx_to_word = {idx: stable.tokenizer.decode(t)
32 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
33 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1}
34 | pprint.pprint(token_idx_to_word)
35 | token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to "
36 | "alter (e.g., 2,5): ")
37 | token_indices = [int(i) for i in token_indices.split(",")]
38 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
39 | return token_indices
40 |
41 |
42 | def run_on_prompt(prompt: List[str],
43 | model: BoxDiffPipeline,
44 | controller: AttentionStore,
45 | token_indices: List[int],
46 | seed: torch.Generator,
47 | config: RunConfig) -> Image.Image:
48 | if controller is not None:
49 | ptp_utils.register_attention_control(model, controller)
50 |
51 | gligen_boxes = []
52 | for i in range(len(config.bbox)):
53 | x1, y1, x2, y2 = config.bbox[i]
54 | gligen_boxes.append([x1/512, y1/512, x2/512, y2/512])
55 |
56 | outputs = model(prompt=prompt,
57 | attention_store=controller,
58 | indices_to_alter=token_indices,
59 | attention_res=config.attention_res,
60 | guidance_scale=config.guidance_scale,
61 | gligen_phrases=config.gligen_phrases,
62 | gligen_boxes=gligen_boxes,
63 | gligen_scheduled_sampling_beta=0.3,
64 | generator=seed,
65 | num_inference_steps=config.n_inference_steps,
66 | max_iter_to_alter=config.max_iter_to_alter,
67 | run_standard_sd=config.run_standard_sd,
68 | thresholds=config.thresholds,
69 | scale_factor=config.scale_factor,
70 | scale_range=config.scale_range,
71 | smooth_attentions=config.smooth_attentions,
72 | sigma=config.sigma,
73 | kernel_size=config.kernel_size,
74 | sd_2_1=config.sd_2_1,
75 | bbox=config.bbox,
76 | config=config)
77 | image = outputs.images[0]
78 | return image
79 |
80 |
81 | @pyrallis.wrap()
82 | def main(config: RunConfig):
83 | stable = load_model()
84 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices
85 |
86 | if len(config.bbox[0]) == 0:
87 | config.bbox = draw_rectangle()
88 |
89 | images = []
90 | for seed in config.seeds:
91 | print(f"Current seed is : {seed}")
92 | g = torch.Generator('cuda').manual_seed(seed)
93 | controller = AttentionStore()
94 | controller.num_uncond_att_layers = -16
95 | image = run_on_prompt(prompt=config.prompt,
96 | model=stable,
97 | controller=controller,
98 | token_indices=token_indices,
99 | seed=g,
100 | config=config)
101 | prompt_output_path = config.output_path / config.prompt[:100]
102 | prompt_output_path.mkdir(exist_ok=True, parents=True)
103 | image.save(prompt_output_path / f'{seed}.png')
104 | images.append(image)
105 |
106 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220)
107 | draw = DashedImageDraw(canvas)
108 |
109 | for i in range(len(config.bbox)):
110 | x1, y1, x2, y2 = config.bbox[i]
111 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5)
112 | canvas.save(prompt_output_path / f'{seed}_canvas.png')
113 |
114 | # save a grid of results across all seeds
115 | joined_image = vis_utils.get_image_grid(images)
116 | joined_image.save(config.output_path / f'{config.prompt}.png')
117 |
118 |
119 | if __name__ == '__main__':
120 | main()
121 |
--------------------------------------------------------------------------------
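In `run_on_prompt` above, the same boxes are used twice: BoxDiff's attention constraints take `config.bbox` in pixel coordinates, while GLIGEN's grounding input (`gligen_boxes`) expects coordinates normalized to `[0, 1]`. Below is a small sketch of that conversion, assuming the 512x512 canvas used throughout the scripts; the box values and phrases are made up for illustration.

```
def to_gligen_boxes(pixel_boxes, canvas_size=512):
    """Convert [x1, y1, x2, y2] pixel boxes to the normalized form expected by GLIGEN."""
    return [[x1 / canvas_size, y1 / canvas_size, x2 / canvas_size, y2 / canvas_size]
            for x1, y1, x2, y2 in pixel_boxes]

# Hypothetical boxes for two phrases such as ["balloon", "basket"].
print(to_gligen_boxes([[64, 96, 384, 480], [80, 128, 360, 256]]))
# -> [[0.125, 0.1875, 0.75, 0.9375], [0.15625, 0.25, 0.703125, 0.5]]
```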
/run_sd_boxdiff.py:
--------------------------------------------------------------------------------
1 |
2 | import pprint
3 | from typing import List
4 |
5 | import pyrallis
6 | import torch
7 | from PIL import Image
8 | from config import RunConfig
9 | from pipeline.sd_pipeline_boxdiff import BoxDiffPipeline
10 | from utils import ptp_utils, vis_utils
11 | from utils.ptp_utils import AttentionStore
12 |
13 | import numpy as np
14 | from utils.drawer import draw_rectangle, DashedImageDraw
15 |
16 | import warnings
17 | warnings.filterwarnings("ignore", category=UserWarning)
18 |
19 |
20 | def load_model(config: RunConfig):
21 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22 |
23 | if config.sd_2_1:
24 | stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
25 | else:
26 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
27 | # If you cannot access Hugging Face from your server, you can point this to a locally prepared copy.
28 | # stable_diffusion_version = "../../packages/huggingface/hub/stable-diffusion-v1-4"
29 | stable = BoxDiffPipeline.from_pretrained(stable_diffusion_version).to(device)
30 |
31 | return stable
32 |
33 |
34 | def get_indices_to_alter(stable, prompt: str) -> List[int]:
35 | token_idx_to_word = {idx: stable.tokenizer.decode(t)
36 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
37 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1}
38 | pprint.pprint(token_idx_to_word)
39 | token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to "
40 | "alter (e.g., 2,5): ")
41 | token_indices = [int(i) for i in token_indices.split(",")]
42 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
43 | return token_indices
44 |
45 |
46 | def run_on_prompt(prompt: List[str],
47 | model: BoxDiffPipeline,
48 | controller: AttentionStore,
49 | token_indices: List[int],
50 | seed: torch.Generator,
51 | config: RunConfig) -> Image.Image:
52 | if controller is not None:
53 | ptp_utils.register_attention_control(model, controller)
54 | outputs = model(prompt=prompt,
55 | attention_store=controller,
56 | indices_to_alter=token_indices,
57 | attention_res=config.attention_res,
58 | guidance_scale=config.guidance_scale,
59 | generator=seed,
60 | num_inference_steps=config.n_inference_steps,
61 | max_iter_to_alter=config.max_iter_to_alter,
62 | run_standard_sd=config.run_standard_sd,
63 | thresholds=config.thresholds,
64 | scale_factor=config.scale_factor,
65 | scale_range=config.scale_range,
66 | smooth_attentions=config.smooth_attentions,
67 | sigma=config.sigma,
68 | kernel_size=config.kernel_size,
69 | sd_2_1=config.sd_2_1,
70 | bbox=config.bbox,
71 | config=config)
72 | image = outputs.images[0]
73 | return image
74 |
75 |
76 | @pyrallis.wrap()
77 | def main(config: RunConfig):
78 | stable = load_model(config)
79 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices
80 |
81 | if len(config.bbox[0]) == 0:
82 | config.bbox = draw_rectangle()
83 |
84 | images = []
85 | for seed in config.seeds:
86 | print(f"Current seed is : {seed}")
87 | g = torch.Generator('cuda').manual_seed(seed)
88 | controller = AttentionStore()
89 | image = run_on_prompt(prompt=config.prompt,
90 | model=stable,
91 | controller=controller,
92 | token_indices=token_indices,
93 | seed=g,
94 | config=config)
95 | prompt_output_path = config.output_path / config.prompt[:100]
96 | prompt_output_path.mkdir(exist_ok=True, parents=True)
97 | image.save(prompt_output_path / f'{seed}.png')
98 | images.append(image)
99 |
100 | canvas = Image.fromarray(np.zeros((image.size[0], image.size[0], 3), dtype=np.uint8) + 220)
101 | draw = DashedImageDraw(canvas)
102 |
103 | for i in range(len(config.bbox)):
104 | x1, y1, x2, y2 = config.bbox[i]
105 | draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=config.color[i], width=5)
106 | canvas.save(prompt_output_path / f'{seed}_canvas.png')
107 |
108 | # save a grid of results across all seeds
109 | joined_image = vis_utils.get_image_grid(images)
110 | joined_image.save(config.output_path / f'{config.prompt}.png')
111 |
112 |
113 | if __name__ == '__main__':
114 | main()
115 |
--------------------------------------------------------------------------------
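`get_indices_to_alter` in both run scripts asks for `token_indices`: the positions of the subject words in the CLIP tokenization of the prompt, where index 0 is the start-of-text token and the last index is the end-of-text token (both are skipped in the printed map). The snippet below is a small, illustrative sketch of that mapping using the `transformers` CLIP tokenizer; the prompt and the suggested indices are examples only, so always check the printed map before choosing.

```
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a dog sitting next to a teddy bear"

# Print index -> word, which is what get_indices_to_alter shows before asking for input.
for idx, token_id in enumerate(tokenizer(prompt)["input_ids"]):
    print(idx, tokenizer.decode(token_id))

# For this prompt the subject tokens are typically "dog" and "bear",
# e.g. token_indices = [2, 8] (verify against the printed map).
```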
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/BoxDiff/b0d5d3b534418aa3fc71b9a16e5b575c0b2ee3b6/utils/__init__.py
--------------------------------------------------------------------------------
/utils/drawer.py:
--------------------------------------------------------------------------------
1 | from tkinter import *
2 | from PIL import ImageDraw as D
3 |
4 | import math
5 | class DashedImageDraw(D.ImageDraw):
6 |
7 | def thick_line(self, xy, direction, fill=None, width=0):
8 |
9 | if xy[0] != xy[1]:
10 | self.line(xy, fill=fill, width=width)
11 | else:
12 | x1, y1 = xy[0]
13 | dx1, dy1 = direction[0]
14 | dx2, dy2 = direction[1]
15 | if dy2 - dy1 < 0:
16 | x1 -= 1
17 | if dx2 - dx1 < 0:
18 | y1 -= 1
19 | if dy2 - dy1 != 0:
20 | if dx2 - dx1 != 0:
21 | k = - (dx2 - dx1) / (dy2 - dy1)
22 | a = 1 / math.sqrt(1 + k ** 2)
23 | b = (width * a - 1) / 2
24 | else:
25 | k = 0
26 | b = (width - 1) / 2
27 | x3 = x1 - math.floor(b)
28 | y3 = y1 - int(k * b)
29 | x4 = x1 + math.ceil(b)
30 | y4 = y1 + int(k * b)
31 | else:
32 | x3 = x1
33 | y3 = y1 - math.floor((width - 1) / 2)
34 | x4 = x1
35 | y4 = y1 + math.ceil((width - 1) / 2)
36 | self.line([(x3, y3), (x4, y4)], fill=fill, width=1)
37 | return
38 |
39 | def dashed_line(self, xy, dash=(2, 2), fill=None, width=0):
40 | for i in range(len(xy) - 1):
41 | x1, y1 = xy[i]
42 | x2, y2 = xy[i + 1]
43 | x_length = x2 - x1
44 | y_length = y2 - y1
45 | length = math.sqrt(x_length ** 2 + y_length ** 2)
46 | dash_enabled = True
47 | postion = 0
48 | while postion <= length:
49 | for dash_step in dash:
50 | if postion > length:
51 | break
52 | if dash_enabled:
53 | start = postion / length
54 | end = min((postion + dash_step - 1) / length, 1)
55 | self.thick_line([(round(x1 + start * x_length),
56 | round(y1 + start * y_length)),
57 | (round(x1 + end * x_length),
58 | round(y1 + end * y_length))],
59 | xy, fill, width)
60 | dash_enabled = not dash_enabled
61 | postion += dash_step
62 | return
63 |
64 | def dashed_rectangle(self, xy, dash=(2, 2), outline=None, width=0):
65 | x1, y1 = xy[0]
66 | x2, y2 = xy[1]
67 | halfwidth1 = math.floor((width - 1) / 2)
68 | halfwidth2 = math.ceil((width - 1) / 2)
69 | min_dash_gap = min(dash[1::2])
70 | end_change1 = halfwidth1 + min_dash_gap + 1
71 | end_change2 = halfwidth2 + min_dash_gap + 1
72 | odd_width_change = (width - 1) % 2
73 | self.dashed_line([(x1 - halfwidth1, y1), (x2 - end_change1, y1)],
74 | dash, outline, width)
75 | self.dashed_line([(x2, y1 - halfwidth1), (x2, y2 - end_change1)],
76 | dash, outline, width)
77 | self.dashed_line([(x2 + halfwidth2, y2 + odd_width_change),
78 | (x1 + end_change2, y2 + odd_width_change)],
79 | dash, outline, width)
80 | self.dashed_line([(x1 + odd_width_change, y2 + halfwidth2),
81 | (x1 + odd_width_change, y1 + end_change2)],
82 | dash, outline, width)
83 | return
84 |
85 | class RectangleDrawer:
86 | def __init__(self, master):
87 | self.master = master
88 | width, height = 512, 512
89 | self.canvas = Canvas(self.master, bg='#F0FFF0', width=width, height=height)
90 | self.canvas.pack()
91 |
92 | self.rectangles = []
93 | self.colors = ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black']
94 |
95 | self.canvas.bind("<ButtonPress-1>", self.on_button_press)
96 | self.canvas.bind("<B1-Motion>", self.on_move_press)
97 | self.canvas.bind("<ButtonRelease-1>", self.on_button_release)
98 | self.start_x = None
99 | self.start_y = None
100 | self.cur_rect = None
101 | self.master.update()
102 | width = self.master.winfo_width()
103 | height = self.master.winfo_height()
104 | x = (self.master.winfo_screenwidth() // 2) - (width // 2)
105 | y = (self.master.winfo_screenheight() // 2) - (height // 2)
106 | self.master.geometry('{}x{}+{}+{}'.format(width, height, x, y))
107 |
108 |
109 | def on_button_press(self, event):
110 | self.start_x = event.x
111 | self.start_y = event.y
112 | self.cur_rect = self.canvas.create_rectangle(self.start_x, self.start_y, self.start_x, self.start_y, outline=self.colors[len(self.rectangles)%len(self.colors)], width=5, dash=(4, 4))
113 |
114 | def on_move_press(self, event):
115 | cur_x, cur_y = (event.x, event.y)
116 | self.canvas.coords(self.cur_rect, self.start_x, self.start_y, cur_x, cur_y)
117 |
118 | def on_button_release(self, event):
119 | cur_x, cur_y = (event.x, event.y)
120 | self.rectangles.append([self.start_x, self.start_y, cur_x, cur_y])
121 | self.cur_rect = None
122 |
123 | def get_rectangles(self):
124 | return self.rectangles
125 |
126 |
127 | def draw_rectangle():
128 | root = Tk()
129 | root.title("Rectangle Drawer")
130 |
131 | drawer = RectangleDrawer(root)
132 |
133 | def on_enter_press(event):
134 | root.quit()
135 |
136 | root.bind('<Return>', on_enter_press)
137 |
138 | root.mainloop()
139 | rectangles = drawer.get_rectangles()
140 |
141 | new_rects = []
142 | for r in rectangles:
143 | new_rects.extend(r)
144 |
145 | return new_rects
146 |
147 | if __name__ == '__main__':
148 | root = Tk()
149 | root.title("Rectangle Drawer")
150 |
151 | drawer = RectangleDrawer(root)
152 |
153 | def on_enter_press(event):
154 | root.quit()
155 |
156 | root.bind('<Return>', on_enter_press)
157 |
158 | root.mainloop()
159 | rectangles = drawer.get_rectangles()
160 |
161 | string = '['
162 | for r in rectangles:
163 | string += '['
164 | for n in r:
165 | string += str(n)
166 | string += ','
167 | string = string[:-1]
168 | string += '],'
169 | string = string[:-1]
170 | string += ']'
171 | print("Rectangles:", string)
--------------------------------------------------------------------------------
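Outside of the interactive `RectangleDrawer`, the run scripts use `DashedImageDraw` only to render the layout canvas that is saved next to each generated image. Below is a minimal sketch of that usage on a blank canvas; the boxes and colors are made up for illustration.

```
import numpy as np
from PIL import Image
from utils.drawer import DashedImageDraw

canvas = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8) + 220)
draw = DashedImageDraw(canvas)

# Pixel-space boxes and outline colors, as in run_sd_boxdiff.py / run_gligen_boxdiff.py.
boxes = [[64, 96, 384, 480], [80, 128, 360, 256]]
colors = ['blue', 'red']
for (x1, y1, x2, y2), color in zip(boxes, colors):
    draw.dashed_rectangle([(x1, y1), (x2, y2)], dash=(5, 5), outline=color, width=5)

canvas.save('layout_canvas.png')
```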
/utils/gaussian_smoothing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class GaussianSmoothing(nn.Module):
9 | """
10 | Apply gaussian smoothing on a
11 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel
12 | in the input using a depthwise convolution.
13 | Arguments:
14 | channels (int, sequence): Number of channels of the input tensors. Output will
15 | have this number of channels as well.
16 | kernel_size (int, sequence): Size of the gaussian kernel.
17 | sigma (float, sequence): Standard deviation of the gaussian kernel.
18 | dim (int, optional): The number of dimensions of the data.
19 | Default value is 2 (spatial).
20 | """
21 | def __init__(self, channels, kernel_size, sigma, dim=2):
22 | super(GaussianSmoothing, self).__init__()
23 | if isinstance(kernel_size, numbers.Number):
24 | kernel_size = [kernel_size] * dim
25 | if isinstance(sigma, numbers.Number):
26 | sigma = [sigma] * dim
27 |
28 | # The gaussian kernel is the product of the
29 | # gaussian function of each dimension.
30 | kernel = 1
31 | meshgrids = torch.meshgrid(
32 | [
33 | torch.arange(size, dtype=torch.float32)
34 | for size in kernel_size
35 | ]
36 | )
37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
38 | mean = (size - 1) / 2
39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
41 |
42 | # Make sure sum of values in gaussian kernel equals 1.
43 | kernel = kernel / torch.sum(kernel)
44 |
45 | # Reshape to depthwise convolutional weight
46 | kernel = kernel.view(1, 1, *kernel.size())
47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
48 |
49 | self.register_buffer('weight', kernel)
50 | self.groups = channels
51 |
52 | if dim == 1:
53 | self.conv = F.conv1d
54 | elif dim == 2:
55 | self.conv = F.conv2d
56 | elif dim == 3:
57 | self.conv = F.conv3d
58 | else:
59 | raise RuntimeError(
60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
61 | )
62 |
63 | def forward(self, input):
64 | """
65 | Apply gaussian filter to input.
66 | Arguments:
67 | input (torch.Tensor): Input to apply gaussian filter on.
68 | Returns:
69 | filtered (torch.Tensor): Filtered output.
70 | """
71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
72 |
73 |
74 | class AverageSmoothing(nn.Module):
75 | """
76 | Apply average smoothing on a
77 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel
78 | in the input using a depthwise convolution.
79 | Arguments:
80 | channels (int, sequence): Number of channels of the input tensors. Output will
81 | have this number of channels as well.
82 | kernel_size (int, sequence): Size of the average kernel.
84 | dim (int, optional): The number of dimensions of the data.
85 | Default value is 2 (spatial).
86 | """
87 | def __init__(self, channels, kernel_size, dim=2):
88 | super(AverageSmoothing, self).__init__()
89 |
90 | # Make sure sum of values in the averaging kernel equals 1.
91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size)
92 |
93 | # Reshape to depthwise convolutional weight
94 | kernel = kernel.view(1, 1, *kernel.size())
95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
96 |
97 | self.register_buffer('weight', kernel)
98 | self.groups = channels
99 |
100 | if dim == 1:
101 | self.conv = F.conv1d
102 | elif dim == 2:
103 | self.conv = F.conv2d
104 | elif dim == 3:
105 | self.conv = F.conv3d
106 | else:
107 | raise RuntimeError(
108 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
109 | )
110 |
111 | def forward(self, input):
112 | """
113 | Apply average filter to input.
114 | Arguments:
115 | input (torch.Tensor): Input to apply average filter on.
116 | Returns:
117 | filtered (torch.Tensor): Filtered output.
118 | """
119 | return self.conv(input, weight=self.weight, groups=self.groups)
120 |
--------------------------------------------------------------------------------
/utils/ptp_utils.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from IPython.display import display
7 | from PIL import Image
8 | from typing import Union, Tuple, List
9 |
10 | from diffusers.models.cross_attention import CrossAttention
11 |
12 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
13 | h, w, c = image.shape
14 | offset = int(h * .2)
15 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
16 | font = cv2.FONT_HERSHEY_SIMPLEX
17 | img[:h] = image
18 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
21 | return img
22 |
23 |
24 | def view_images(images: Union[np.ndarray, List],
25 | num_rows: int = 1,
26 | offset_ratio: float = 0.02,
27 | display_image: bool = True) -> Image.Image:
28 | """ Displays a list of images in a grid. """
29 | if type(images) is list:
30 | num_empty = len(images) % num_rows
31 | elif images.ndim == 4:
32 | num_empty = images.shape[0] % num_rows
33 | else:
34 | images = [images]
35 | num_empty = 0
36 |
37 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
38 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
39 | num_items = len(images)
40 |
41 | h, w, c = images[0].shape
42 | offset = int(h * offset_ratio)
43 | num_cols = num_items // num_rows
44 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
45 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
46 | for i in range(num_rows):
47 | for j in range(num_cols):
48 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
49 | i * num_cols + j]
50 |
51 | pil_img = Image.fromarray(image_)
52 | if display_image:
53 | display(pil_img)
54 | return pil_img
55 |
56 |
57 | class AttendExciteCrossAttnProcessor:
58 |
59 | def __init__(self, attnstore, place_in_unet):
60 | super().__init__()
61 | self.attnstore = attnstore
62 | self.place_in_unet = place_in_unet
63 |
64 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
65 | batch_size, sequence_length, _ = hidden_states.shape
66 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1)
67 | query = attn.to_q(hidden_states)
68 |
69 | is_cross = encoder_hidden_states is not None
70 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
71 | key = attn.to_k(encoder_hidden_states)
72 | value = attn.to_v(encoder_hidden_states)
73 |
74 | query = attn.head_to_batch_dim(query)
75 | key = attn.head_to_batch_dim(key)
76 | value = attn.head_to_batch_dim(value)
77 |
78 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
79 |
80 | self.attnstore(attention_probs, is_cross, self.place_in_unet)
81 |
82 | hidden_states = torch.bmm(attention_probs, value)
83 | hidden_states = attn.batch_to_head_dim(hidden_states)
84 |
85 | # linear proj
86 | hidden_states = attn.to_out[0](hidden_states)
87 | # dropout
88 | hidden_states = attn.to_out[1](hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | def register_attention_control(model, controller):
94 |
95 | attn_procs = {}
96 | cross_att_count = 0
97 | for name in model.unet.attn_processors.keys():
98 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
99 | if name.startswith("mid_block"):
100 | hidden_size = model.unet.config.block_out_channels[-1]
101 | place_in_unet = "mid"
102 | elif name.startswith("up_blocks"):
103 | block_id = int(name[len("up_blocks.")])
104 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
105 | place_in_unet = "up"
106 | elif name.startswith("down_blocks"):
107 | block_id = int(name[len("down_blocks.")])
108 | hidden_size = model.unet.config.block_out_channels[block_id]
109 | place_in_unet = "down"
110 | else:
111 | continue
112 |
113 | cross_att_count += 1
114 | attn_procs[name] = AttendExciteCrossAttnProcessor(
115 | attnstore=controller, place_in_unet=place_in_unet
116 | )
117 | model.unet.set_attn_processor(attn_procs)
118 | controller.num_att_layers = cross_att_count
119 |
120 | class AttentionControl(abc.ABC):
121 |
122 | def step_callback(self, x_t):
123 | return x_t
124 |
125 | def between_steps(self):
126 | return
127 |
128 | # @property
129 | # def num_uncond_att_layers(self):
130 | # return 0
131 |
132 | @abc.abstractmethod
133 | def forward(self, attn, is_cross: bool, place_in_unet: str):
134 | raise NotImplementedError
135 |
136 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
137 | if self.cur_att_layer >= self.num_uncond_att_layers:
138 | self.forward(attn, is_cross, place_in_unet)
139 | self.cur_att_layer += 1
140 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
141 | self.cur_att_layer = 0
142 | self.cur_step += 1
143 | self.between_steps()
144 |
145 | def reset(self):
146 | self.cur_step = 0
147 | self.cur_att_layer = 0
148 |
149 | def __init__(self):
150 | self.cur_step = 0
151 | self.num_att_layers = -1
152 | self.cur_att_layer = 0
153 |
154 |
155 | class EmptyControl(AttentionControl):
156 |
157 | def forward(self, attn, is_cross: bool, place_in_unet: str):
158 | return attn
159 |
160 |
161 | class AttentionStore(AttentionControl):
162 |
163 | @staticmethod
164 | def get_empty_store():
165 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
166 | "down_self": [], "mid_self": [], "up_self": []}
167 |
168 | def forward(self, attn, is_cross: bool, place_in_unet: str):
169 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
170 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
171 | self.step_store[key].append(attn)
172 | return attn
173 |
174 | def between_steps(self):
175 | self.attention_store = self.step_store
176 | if self.save_global_store:
177 | with torch.no_grad():
178 | if len(self.global_store) == 0:
179 | self.global_store = self.step_store
180 | else:
181 | for key in self.global_store:
182 | for i in range(len(self.global_store[key])):
183 | self.global_store[key][i] += self.step_store[key][i].detach()
184 | self.step_store = self.get_empty_store()
185 | self.step_store = self.get_empty_store()
186 |
187 | def get_average_attention(self):
188 | average_attention = self.attention_store
189 | return average_attention
190 |
191 | def get_average_global_attention(self):
192 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
193 | self.attention_store}
194 | return average_attention
195 |
196 | def reset(self):
197 | super(AttentionStore, self).reset()
198 | self.step_store = self.get_empty_store()
199 | self.attention_store = {}
200 | self.global_store = {}
201 |
202 | def __init__(self, save_global_store=False):
203 | '''
204 | Initialize an empty AttentionStore
205 | :param step_index: used to visualize only a specific step in the diffusion process
206 | '''
207 | super(AttentionStore, self).__init__()
208 | self.save_global_store = save_global_store
209 | self.step_store = self.get_empty_store()
210 | self.attention_store = {}
211 | self.global_store = {}
212 | self.curr_step_index = 0
213 | self.num_uncond_att_layers = 0
214 |
215 |
216 | def aggregate_attention(attention_store: AttentionStore,
217 | res: int,
218 | from_where: List[str],
219 | is_cross: bool,
220 | select: int) -> torch.Tensor:
221 | """ Aggregates the attention across the different layers and heads at the specified resolution. """
222 | out = []
223 | attention_maps = attention_store.get_average_attention()
224 |
225 | num_pixels = res ** 2
226 | for location in from_where:
227 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
228 | if item.shape[1] == num_pixels:
229 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select]
230 | out.append(cross_maps)
231 | out = torch.cat(out, dim=0)
232 | out = out.sum(0) / out.shape[0]
233 | return out
234 |
--------------------------------------------------------------------------------
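Putting the hooks together: `register_attention_control` replaces every attention processor in the UNet with `AttendExciteCrossAttnProcessor`, which streams attention probabilities into an `AttentionStore`, and `aggregate_attention` later averages the stored maps at a chosen resolution. The helpers below are a small sketch of that flow; they only wrap calls that already appear in the run scripts and in the pipeline, and the aggregation is meant to be used after a generation pass (e.g. `run_on_prompt` above) has filled the store.

```
import torch
from pipeline.sd_pipeline_boxdiff import BoxDiffPipeline
from utils import ptp_utils
from utils.ptp_utils import AttentionStore, aggregate_attention


def attach_attention_store(stable: BoxDiffPipeline) -> AttentionStore:
    """Hook every UNet attention layer so the store records attention during generation."""
    controller = AttentionStore()
    ptp_utils.register_attention_control(stable, controller)
    return controller


def averaged_cross_attention(controller: AttentionStore, res: int = 16) -> torch.Tensor:
    """Average the recorded cross-attention maps, as the pipeline does internally.

    Returns a (res, res, sequence_length) tensor; only meaningful after the pipeline
    has been run with this controller attached.
    """
    return aggregate_attention(attention_store=controller,
                               res=res,
                               from_where=("up", "down", "mid"),
                               is_cross=True,
                               select=0)
```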
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | from PIL import Image
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from utils import ptp_utils
9 | from utils.ptp_utils import AttentionStore, aggregate_attention
10 |
11 |
12 | def show_cross_attention(prompt: str,
13 | attention_store: AttentionStore,
14 | tokenizer,
15 | indices_to_alter: List[int],
16 | res: int,
17 | from_where: List[str],
18 | select: int = 0,
19 | orig_image=None):
20 | tokens = tokenizer.encode(prompt)
21 | decoder = tokenizer.decode
22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu()
23 | images = []
24 |
25 | # show spatial attention for indices of tokens to strengthen
26 | for i in range(len(tokens)):
27 | image = attention_maps[:, :, i]
28 | if i in indices_to_alter:
29 | image = show_image_relevance(image, orig_image)
30 | image = image.astype(np.uint8)
31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2)))
32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
33 | images.append(image)
34 |
35 | ptp_utils.view_images(np.stack(images, axis=0))
36 |
37 |
38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16):
39 | # create heatmap from mask on image
40 | def show_cam_on_image(img, mask):
41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
42 | heatmap = np.float32(heatmap) / 255
43 | cam = heatmap + np.float32(img)
44 | cam = cam / np.max(cam)
45 | return cam
46 |
47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2))
48 | image = np.array(image)
49 |
50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])
51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu
52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear')
53 | image_relevance = image_relevance.cpu() # send it back to cpu
54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2)
56 | image = (image - image.min()) / (image.max() - image.min())
57 | vis = show_cam_on_image(image, image_relevance)
58 | vis = np.uint8(255 * vis)
59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
60 | return vis
61 |
62 |
63 | def get_image_grid(images: List[Image.Image]) -> Image:
64 | num_images = len(images)
65 | cols = int(math.ceil(math.sqrt(num_images)))
66 | rows = int(math.ceil(num_images / cols))
67 | width, height = images[0].size
68 | grid_image = Image.new('RGB', (cols * width, rows * height))
69 | for i, img in enumerate(images):
70 | x = i % cols
71 | y = i // cols
72 | grid_image.paste(img, (x * width, y * height))
73 | return grid_image
74 |
--------------------------------------------------------------------------------
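Finally, `get_image_grid` is what the run scripts use to tile the per-seed results into a single, roughly square grid. A tiny sketch with solid-color placeholder images:

```
from PIL import Image
from utils.vis_utils import get_image_grid

# Three 256x256 placeholder images; the run scripts pass the per-seed generations here.
images = [Image.new('RGB', (256, 256), color) for color in ('red', 'green', 'blue')]
grid = get_image_grid(images)
print(grid.size)   # (512, 512): a 2x2 grid with the last cell left black
```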