├── LICENSE
├── README.md
├── config.py
├── docs
│   ├── explainability.jpg
│   ├── results.jpg
│   └── teaser.jpg
├── environment
│   ├── environment.yaml
│   └── requirements.txt
├── metrics
│   ├── __init__.py
│   ├── blip_captioning_and_clip_similarity.py
│   ├── compute_clip_similarity.py
│   └── imagenet_utils.py
├── notebooks
│   ├── explain.ipynb
│   └── generate_images.ipynb
├── pipeline_attend_and_excite.py
├── run.py
└── utils
    ├── __init__.py
    ├── gaussian_smoothing.py
    ├── ptp_utils.py
    └── vis_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 AttendAndExcite
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models (SIGGRAPH 2023)
2 |
3 | > Hila Chefer*, Yuval Alaluf*, Yael Vinker, Lior Wolf, Daniel Cohen-Or
4 | > Tel Aviv University
5 | > \* Denotes equal contribution
6 | >
7 | > Recent text-to-image generative models have demonstrated an unparalleled ability to generate diverse and creative imagery guided by a target text prompt. While revolutionary, current state-of-the-art diffusion models may still fail in generating images that fully convey the semantics in the given text prompt. We analyze the publicly available Stable Diffusion model and assess the existence of catastrophic neglect, where the model fails to generate one or more of the subjects from the input prompt. Moreover, we find that in some cases the model also fails to correctly bind attributes (e.g., colors) to their corresponding subjects. To help mitigate these failure cases, we introduce the concept of Generative Semantic Nursing (GSN), where we seek to intervene in the generative process on the fly during inference time to improve the faithfulness of the generated images. Using an attention-based formulation of GSN, dubbed Attend-and-Excite, we guide the model to refine the cross-attention units to attend to all subject tokens in the text prompt and strengthen — or excite — their activations, encouraging the model to generate all subjects described in the text prompt. We compare our approach to alternative approaches and demonstrate that it conveys the desired concepts more faithfully across a range of text prompts.
8 |
9 |
10 |
11 |
12 | [Hugging Face Spaces Demo](https://huggingface.co/spaces/hysts/Attend-and-Excite)
13 | [Replicate Demo](https://replicate.com/daanelson/attend-and-excite)
14 |
15 |
16 |
17 |
18 | Given a pre-trained text-to-image diffusion model (e.g., Stable Diffusion) our method, Attend-and-Excite, guides the generative model to modify the cross-attention values during the image synthesis process to generate images that more faithfully depict the input text prompt. Stable Diffusion alone (top row) struggles to generate multiple objects (e.g., a horse and a dog). However, by incorporating Attend-and-Excite (bottom row) to strengthen the subject tokens (marked in blue), we achieve images that are more semantically faithful with respect to the input text prompts.
19 |
20 |
21 | ## Description
22 | Official implementation of our Attend-and-Excite paper.
23 |
24 | ## Setup
25 |
26 | ### Environment
27 | Our code builds on the requirements of the official [Stable Diffusion repository](https://github.com/CompVis/stable-diffusion). To set up the environment, please run:
28 |
29 | ```
30 | conda env create -f environment/environment.yaml
31 | conda activate ldm
32 | ```
33 |
34 | On top of these requirements, we add several additional requirements, listed in `environment/requirements.txt`. They are installed automatically by the command above.
35 |
36 | ### Hugging Face Diffusers Library
37 | Our code also relies on Hugging Face's [diffusers](https://github.com/huggingface/diffusers) library for downloading the Stable Diffusion v1.4 model.
38 |
39 |
40 | ## Usage
41 |
42 |
43 |
44 |
45 | Example generations produced by Stable Diffusion with Attend-and-Excite.
46 |
47 |
48 | To generate an image, you can simply run the `run.py` script. For example,
49 | ```
50 | python run.py --prompt "a cat and a dog" --seeds [0] --token_indices [2,5]
51 | ```
52 | Notes:
53 |
54 | - To apply Attend-and-Excite on Stable Diffusion 2.1, specify: `--sd_2_1 True`
55 | - You may run multiple seeds by passing a list of seeds. For example, `--seeds [0,1,2,3]`.
56 | - If you do not specify which token indices to alter via `--token_indices`, we will split the text using Stable Diffusion's tokenizer and display the index of each token. You will then be prompted to enter the indices you wish to alter.
57 | - If you wish to run the standard Stable Diffusion model without Attend-and-Excite, you can do so by passing `--run_standard_sd True`.
58 | - All parameters are defined in `config.py` and are set to their defaults according to the official paper.
59 |
60 | All generated images will be saved to the path `"{config.output_path}/{prompt}"`. We will also save a grid of all images (in the case of multiple seeds) under `config.output_path`.
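Since `run.py` is configured through the `RunConfig` dataclass via `pyrallis`, the remaining parameters in `config.py` can also be overridden from the command line. As a rough sketch (assuming the default flag names that `pyrallis` derives from the `RunConfig` fields):
```
python run.py --prompt "a cat and a dog" --seeds [0,1] --token_indices [2,5] --max_iter_to_alter 25 --scale_factor 20
```
With the default `output_path`, the per-seed images for this example are saved under `outputs/a cat and a dog/`.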
61 |
62 | ### Float16 Precision
63 | When loading the Stable Diffusion model, you can pass `torch_dtype=torch.float16` to reduce memory usage and speed up inference:
64 | ```python
65 | stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to(device)
66 | ```
67 | Note that this may result in a slight degradation of results in some cases.
68 |
69 | ## Notebooks
70 | We provide Jupyter notebooks to reproduce the results from the paper for image generation and explainability via the
71 | cross-attention maps.
72 |
73 |
74 |
75 |
76 | Example cross-attention visualizations.
77 |
78 |
79 | ### Generation
80 | `notebooks/generate_images.ipynb` enables image generation using a free-form text prompt with and without Attend-and-Excite.
81 |
82 | ### Explainability
83 | `notebooks/explain.ipynb` produces a comparison of the cross-attention maps before and after applying Attend-and-Excite
84 | as seen in the illustration above.
85 | This notebook can be used to provide an explanation for the generations produced by Attend-and-Excite.
86 |
87 | ## Metrics
88 | In `metrics/` we provide code needed to reproduce the quantitative experiments presented in the paper:
89 | 1. In `compute_clip_similarity.py`, we provide the code needed for computing the image-based CLIP similarities. Here, we compute the CLIP-space similarities between the generated images and the guiding text prompt.
90 | 2. In `blip_captioning_and_clip_similarity.py`, we provide the code needed for computing the text-based CLIP similarities. Here, we generate captions for each generated image using BLIP and compute the CLIP-space similarities between the generated captions and the guiding text prompt.
91 | - Note: to run this script you need to install the [LAVIS](https://github.com/salesforce/LAVIS) library. This can be done using `pip install salesforce-lavis`.
92 |
93 | To run the scripts, you simply need to pass the output directory containing the generated images. The directory structure should be as follows:
94 | ```
95 | outputs/
96 | |-- prompt_1/
97 | | |-- 0.png
98 | | |-- 1.png
99 | | |-- ...
100 | | |-- 64.png
101 | |-- prompt_2/
102 | | |-- 0.png
103 | | |-- 1.png
104 | | |-- ...
105 | | |-- 64.png
106 | ...
107 | ```
108 | The scripts will iterate through all the prompt outputs provided in the root output directory and aggregate results across all images.
109 |
110 | The metrics will be saved to a `json` file under the path specified by `--metrics_save_path`.
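Both scripts are configured via `pyrallis` through their `EvalConfig` dataclass, so a typical invocation might look as follows (a sketch assuming the default flag names derived from `EvalConfig`):
```
python metrics/compute_clip_similarity.py --output_path outputs/ --metrics_save_path metrics/
python metrics/blip_captioning_and_clip_similarity.py --output_path outputs/ --metrics_save_path metrics/
```
The first script writes `clip_raw_metrics.json` and `clip_aggregated_metrics.json`, and the second writes `blip_raw_metrics.json` and `blip_aggregated_metrics.json`, all under the chosen `--metrics_save_path`.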
111 |
112 | ### Evaluation Prompts
113 | The prompts used in our quantitative evaluations can be found [here](https://github.com/AttendAndExcite/Attend-and-Excite/files/11336216/a.e_prompts.txt).
114 |
115 | ## Acknowledgements
116 | This code builds on the code from the [diffusers](https://github.com/huggingface/diffusers) library as well as the [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) codebase.
117 |
118 | ## Citation
119 | If you use this code for your research, please cite the following work:
120 | ```
121 | @misc{chefer2023attendandexcite,
122 | title={Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models},
123 | author={Hila Chefer and Yuval Alaluf and Yael Vinker and Lior Wolf and Daniel Cohen-Or},
124 | year={2023},
125 | eprint={2301.13826},
126 | archivePrefix={arXiv},
127 | primaryClass={cs.CV}
128 | }
129 | ```
130 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from pathlib import Path
3 | from typing import Dict, List
4 |
5 |
6 | @dataclass
7 | class RunConfig:
8 | # Guiding text prompt
9 | prompt: str
10 | # Whether to use Stable Diffusion v2.1
11 | sd_2_1: bool = False
12 | # Which token indices to alter with attend-and-excite
13 | token_indices: List[int] = None
14 | # Which random seeds to use when generating
15 | seeds: List[int] = field(default_factory=lambda: [42])
16 | # Path to save all outputs to
17 | output_path: Path = Path('./outputs')
18 | # Number of denoising steps
19 | n_inference_steps: int = 50
20 | # Text guidance scale
21 | guidance_scale: float = 7.5
22 | # Number of denoising steps to apply attend-and-excite
23 | max_iter_to_alter: int = 25
24 | # Resolution of UNet to compute attention maps over
25 | attention_res: int = 16
26 | # Whether to run standard SD or attend-and-excite
27 | run_standard_sd: bool = False
28 | # Iterations at which to apply iterative latent refinement, mapped to the desired attention threshold at each
29 | thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
30 | # Scale factor for updating the denoised latent z_t
31 | scale_factor: int = 20
32 | # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
33 | scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
34 | # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
35 | smooth_attentions: bool = True
36 | # Standard deviation for the Gaussian smoothing
37 | sigma: float = 0.5
38 | # Kernel size for the Gaussian smoothing
39 | kernel_size: int = 3
40 | # Whether to save cross attention maps for the final results
41 | save_cross_attention_maps: bool = False
42 |
43 | def __post_init__(self):
44 | self.output_path.mkdir(exist_ok=True, parents=True)
45 |
--------------------------------------------------------------------------------
/docs/explainability.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/explainability.jpg
--------------------------------------------------------------------------------
/docs/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/results.jpg
--------------------------------------------------------------------------------
/docs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/teaser.jpg
--------------------------------------------------------------------------------
/environment/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - torchmetrics==0.6.0
27 | - kornia==0.6
28 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
29 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
30 | - -r requirements.txt
31 |
--------------------------------------------------------------------------------
/environment/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | opencv-python
3 | ipywidgets
4 | matplotlib
5 | pyrallis
6 | torch==1.12.0
7 | diffusers==0.12.1
8 | transformers==4.26.0
9 | jupyter
10 | accelerate
11 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/blip_captioning_and_clip_similarity.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 |
6 | import clip
7 | import numpy as np
8 | import pyrallis
9 | import torch
10 | from PIL import Image
11 | from lavis.models import load_model_and_preprocess
12 | from tqdm import tqdm
13 |
14 | sys.path.append(".")
15 | sys.path.append("..")
16 |
17 | from metrics.imagenet_utils import get_embedding_for_prompt, imagenet_templates
18 |
19 |
20 | @dataclass
21 | class EvalConfig:
22 | output_path: Path = Path("./outputs/")
23 | metrics_save_path: Path = Path("./metrics/")
24 |
25 | def __post_init__(self):
26 | self.metrics_save_path.mkdir(parents=True, exist_ok=True)
27 |
28 |
29 | @pyrallis.wrap()
30 | def run(config: EvalConfig):
31 | print("Loading CLIP model...")
32 | device = torch.device("cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
33 | model, preprocess = clip.load("ViT-B/16", device)
34 | model.eval()
35 | print("Done.")
36 |
37 | print("Loading BLIP model...")
38 | blip_model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco",
39 | is_eval=True, device=device)
40 | print("Done.")
41 |
42 | prompts = [p.name for p in config.output_path.glob("*") if p.is_dir()]
43 | print(f"Running on {len(prompts)} prompts...")
44 |
45 | results_per_prompt = {}
46 | for prompt in tqdm(prompts):
47 | print(f'Running on: "{prompt}"')
48 |
49 | # get all images for the given prompt
50 | image_paths = [p for p in (config.output_path / prompt).rglob('*') if p.suffix in ['.png', '.jpg']]
51 | images = [Image.open(p) for p in image_paths]
52 | image_names = [p.name for p in image_paths]
53 |
54 | with torch.no_grad():
55 | # extract prompt embeddings
56 | prompt_features = get_embedding_for_prompt(model, prompt, templates=imagenet_templates)
57 |
58 | # extract blip captions and embeddings
59 | blip_input_images = [vis_processors["eval"](image).unsqueeze(0).to(device) for image in images]
60 | blip_captions = [blip_model.generate({"image": image})[0] for image in blip_input_images]
61 | texts = [clip.tokenize([text]).cuda() for text in blip_captions]
62 | caption_embeddings = [model.encode_text(t) for t in texts]
63 | caption_embeddings = [embedding / embedding.norm(dim=-1, keepdim=True) for embedding in caption_embeddings]
64 |
65 | text_similarities = [(caption_embedding.float() @ prompt_features.T).item()
66 | for caption_embedding in caption_embeddings]
67 |
68 | results_per_prompt[prompt] = {
69 | 'text_similarities': text_similarities,
70 | 'captions': blip_captions,
71 | 'image_names': image_names,
72 | }
73 |
74 | # aggregate results
75 | total_average, total_std = aggregate_text_similarities(results_per_prompt)
76 | aggregated_results = {
77 | 'average_similarity': total_average,
78 | 'std_similarity': total_std,
79 | }
80 |
81 | with open(config.metrics_save_path / "blip_raw_metrics.json", 'w') as f:
82 | json.dump(results_per_prompt, f, sort_keys=True, indent=4)
83 | with open(config.metrics_save_path / "blip_aggregated_metrics.json", 'w') as f:
84 | json.dump(aggregated_results, f, sort_keys=True, indent=4)
85 |
86 |
87 | def aggregate_text_similarities(result_dict):
88 | all_averages = [result_dict[prompt]['text_similarities'] for prompt in result_dict]
89 | all_averages = np.array(all_averages).flatten()
90 | total_average = np.average(all_averages)
91 | total_std = np.std(all_averages)
92 | return total_average, total_std
93 |
94 |
95 | if __name__ == '__main__':
96 | run()
97 |
--------------------------------------------------------------------------------
/metrics/compute_clip_similarity.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 |
6 | import clip
7 | import numpy as np
8 | import pyrallis
9 | import torch
10 | from PIL import Image
11 | from tqdm import tqdm
12 |
13 | sys.path.append(".")
14 | sys.path.append("..")
15 |
16 | from metrics.imagenet_utils import get_embedding_for_prompt, imagenet_templates
17 |
18 |
19 | @dataclass
20 | class EvalConfig:
21 | output_path: Path = Path("./outputs/")
22 | metrics_save_path: Path = Path("./metrics/")
23 |
24 | def __post_init__(self):
25 | self.metrics_save_path.mkdir(parents=True, exist_ok=True)
26 |
27 |
28 | @pyrallis.wrap()
29 | def run(config: EvalConfig):
30 | print("Loading CLIP model...")
31 | device = torch.device("cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu")
32 | model, preprocess = clip.load("ViT-B/16", device)
33 | model.eval()
34 | print("Done.")
35 |
36 | prompts = [p.name for p in config.output_path.glob("*") if p.is_dir()]
37 | print(f"Running on {len(prompts)} prompts...")
38 |
39 | results_per_prompt = {}
40 | for prompt in tqdm(prompts):
41 |
42 | print(f'Running on: "{prompt}"')
43 |
44 | # get all images for the given prompt
45 | image_paths = [p for p in (config.output_path / prompt).rglob('*') if p.suffix in ['.png', '.jpg']]
46 | images = [Image.open(p) for p in image_paths]
47 | image_names = [p.name for p in image_paths]
48 | queries = [preprocess(image).unsqueeze(0).to(device) for image in images]
49 |
50 | with torch.no_grad():
51 |
52 | # split prompt into first and second halves
53 | if ' and ' in prompt:
54 | prompt_parts = prompt.split(' and ')
55 | elif ' with ' in prompt:
56 | prompt_parts = prompt.split(' with ')
57 | else:
58 | print(f"Unable to split prompt: {prompt}. "
59 | f"Looking for 'and' or 'with' for splitting! Skipping!")
60 | continue
61 |
62 | # extract text features
63 | full_text_features = get_embedding_for_prompt(model, prompt, templates=imagenet_templates)
64 | first_half_features = get_embedding_for_prompt(model, prompt_parts[0], templates=imagenet_templates)
65 | second_half_features = get_embedding_for_prompt(model, prompt_parts[1], templates=imagenet_templates)
66 |
67 | # extract image features
68 | images_features = [model.encode_image(image) for image in queries]
69 | images_features = [feats / feats.norm(dim=-1, keepdim=True) for feats in images_features]
70 |
71 | # compute similarities
72 | full_text_similarities = [(feat.float() @ full_text_features.T).item() for feat in images_features]
73 | first_half_similarities = [(feat.float() @ first_half_features.T).item() for feat in images_features]
74 | second_half_similarities = [(feat.float() @ second_half_features.T).item() for feat in images_features]
75 |
76 | results_per_prompt[prompt] = {
77 | 'full_text': full_text_similarities,
78 | 'first_half': first_half_similarities,
79 | 'second_half': second_half_similarities,
80 | 'image_names': image_names,
81 | }
82 |
83 | # aggregate results
84 | aggregated_results = {
85 | 'full_text_aggregation': aggregate_by_full_text(results_per_prompt),
86 | 'min_first_second_aggregation': aggregate_by_min_half(results_per_prompt),
87 | }
88 |
89 | with open(config.metrics_save_path / "clip_raw_metrics.json", 'w') as f:
90 | json.dump(results_per_prompt, f, sort_keys=True, indent=4)
91 | with open(config.metrics_save_path / "clip_aggregated_metrics.json", 'w') as f:
92 | json.dump(aggregated_results, f, sort_keys=True, indent=4)
93 |
94 |
95 | def aggregate_by_min_half(d):
96 | """ Aggregate results for the minimum similarity score for each prompt. """
97 | min_per_half_res = [[min(a, b) for a, b in zip(d[prompt]["first_half"], d[prompt]["second_half"])] for prompt in d]
98 | min_per_half_res = np.array(min_per_half_res).flatten()
99 | return np.average(min_per_half_res)
100 |
101 |
102 | def aggregate_by_full_text(d):
103 | """ Aggregate results for the full text similarity for each prompt. """
104 | full_text_res = [v['full_text'] for v in d.values()]
105 | full_text_res = np.array(full_text_res).flatten()
106 | return np.average(full_text_res)
107 |
108 |
109 | if __name__ == '__main__':
110 | run()
111 |
--------------------------------------------------------------------------------
/metrics/imagenet_utils.py:
--------------------------------------------------------------------------------
1 | import clip
2 |
3 | imagenet_templates = [
4 | 'a bad photo of a {}.',
5 | 'a photo of many {}.',
6 | 'a sculpture of a {}.',
7 | 'a photo of the hard to see {}.',
8 | 'a low resolution photo of the {}.',
9 | 'a rendering of a {}.',
10 | 'graffiti of a {}.',
11 | 'a bad photo of the {}.',
12 | 'a cropped photo of the {}.',
13 | 'a tattoo of a {}.',
14 | 'the embroidered {}.',
15 | 'a photo of a hard to see {}.',
16 | 'a bright photo of a {}.',
17 | 'a photo of a clean {}.',
18 | 'a photo of a dirty {}.',
19 | 'a dark photo of the {}.',
20 | 'a drawing of a {}.',
21 | 'a photo of my {}.',
22 | 'the plastic {}.',
23 | 'a photo of the cool {}.',
24 | 'a close-up photo of a {}.',
25 | 'a black and white photo of the {}.',
26 | 'a painting of the {}.',
27 | 'a painting of a {}.',
28 | 'a pixelated photo of the {}.',
29 | 'a sculpture of the {}.',
30 | 'a bright photo of the {}.',
31 | 'a cropped photo of a {}.',
32 | 'a plastic {}.',
33 | 'a photo of the dirty {}.',
34 | 'a jpeg corrupted photo of a {}.',
35 | 'a blurry photo of the {}.',
36 | 'a photo of the {}.',
37 | 'a good photo of the {}.',
38 | 'a rendering of the {}.',
39 | 'a {} in a video game.',
40 | 'a photo of one {}.',
41 | 'a doodle of a {}.',
42 | 'a close-up photo of the {}.',
43 | 'a photo of a {}.',
44 | 'the origami {}.',
45 | 'the {} in a video game.',
46 | 'a sketch of a {}.',
47 | 'a doodle of the {}.',
48 | 'a origami {}.',
49 | 'a low resolution photo of a {}.',
50 | 'the toy {}.',
51 | 'a rendition of the {}.',
52 | 'a photo of the clean {}.',
53 | 'a photo of a large {}.',
54 | 'a rendition of a {}.',
55 | 'a photo of a nice {}.',
56 | 'a photo of a weird {}.',
57 | 'a blurry photo of a {}.',
58 | 'a cartoon {}.',
59 | 'art of a {}.',
60 | 'a sketch of the {}.',
61 | 'a embroidered {}.',
62 | 'a pixelated photo of a {}.',
63 | 'itap of the {}.',
64 | 'a jpeg corrupted photo of the {}.',
65 | 'a good photo of a {}.',
66 | 'a plushie {}.',
67 | 'a photo of the nice {}.',
68 | 'a photo of the small {}.',
69 | 'a photo of the weird {}.',
70 | 'the cartoon {}.',
71 | 'art of the {}.',
72 | 'a drawing of the {}.',
73 | 'a photo of the large {}.',
74 | 'a black and white photo of a {}.',
75 | 'the plushie {}.',
76 | 'a dark photo of a {}.',
77 | 'itap of a {}.',
78 | 'graffiti of the {}.',
79 | 'a toy {}.',
80 | 'itap of my {}.',
81 | 'a photo of a cool {}.',
82 | 'a photo of a small {}.',
83 | 'a tattoo of the {}.',
84 | ]
85 |
86 |
87 | def get_embedding_for_prompt(model, prompt, templates):
88 | texts = [template.format(prompt) for template in templates] # format with class
89 | texts = [t.replace('a a', 'a') for t in texts] # remove double a's
90 | texts = [t.replace('the a', 'a') for t in texts]  # replace 'the a' with 'a'
91 | texts = clip.tokenize(texts).cuda() # tokenize
92 | class_embeddings = model.encode_text(texts) # embed with text encoder
93 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
94 | class_embedding = class_embeddings.mean(dim=0)
95 | class_embedding /= class_embedding.norm()
96 | return class_embedding.float()
97 |
--------------------------------------------------------------------------------
/pipeline_attend_and_excite.py:
--------------------------------------------------------------------------------
1 |
2 | import inspect
3 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | from torch.nn import functional as F
8 |
9 | from packaging import version
10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
11 |
12 | from diffusers.configuration_utils import FrozenDict
13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
14 | from diffusers.schedulers import KarrasDiffusionSchedulers
15 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
16 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
18 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
19 |
20 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
21 |
22 | from utils.gaussian_smoothing import GaussianSmoothing
23 | from utils.ptp_utils import AttentionStore, aggregate_attention
24 |
25 | logger = logging.get_logger(__name__)
26 |
27 | class AttendAndExcitePipeline(StableDiffusionPipeline):
28 | r"""
29 | Pipeline for text-to-image generation using Stable Diffusion.
30 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
31 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
32 | Args:
33 | vae ([`AutoencoderKL`]):
34 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
35 | text_encoder ([`CLIPTextModel`]):
36 | Frozen text-encoder. Stable Diffusion uses the text portion of
37 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
38 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
39 | tokenizer (`CLIPTokenizer`):
40 | Tokenizer of class
41 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
42 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
43 | scheduler ([`SchedulerMixin`]):
44 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
45 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
46 | safety_checker ([`StableDiffusionSafetyChecker`]):
47 | Classification module that estimates whether generated images could be considered offensive or harmful.
48 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
49 | feature_extractor ([`CLIPFeatureExtractor`]):
50 | Model that extracts features from generated images to be used as inputs for the `safety_checker`.
51 | """
52 | _optional_components = ["safety_checker", "feature_extractor"]
53 |
54 | def _encode_prompt(
55 | self,
56 | prompt,
57 | device,
58 | num_images_per_prompt,
59 | do_classifier_free_guidance,
60 | negative_prompt=None,
61 | prompt_embeds: Optional[torch.FloatTensor] = None,
62 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
63 | ):
64 | r"""
65 | Encodes the prompt into text encoder hidden states.
66 | Args:
67 | prompt (`str` or `List[str]`, *optional*):
68 | prompt to be encoded
69 | device: (`torch.device`):
70 | torch device
71 | num_images_per_prompt (`int`):
72 | number of images that should be generated per prompt
73 | do_classifier_free_guidance (`bool`):
74 | whether to use classifier free guidance or not
75 | negative_prompt (`str` or `List[str]`, *optional*):
76 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
77 | `negative_prompt_embeds` instead.
78 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
79 | prompt_embeds (`torch.FloatTensor`, *optional*):
80 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
81 | provided, text embeddings will be generated from `prompt` input argument.
82 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
83 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
84 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
85 | argument.
86 | """
87 | if prompt is not None and isinstance(prompt, str):
88 | batch_size = 1
89 | elif prompt is not None and isinstance(prompt, list):
90 | batch_size = len(prompt)
91 | else:
92 | batch_size = prompt_embeds.shape[0]
93 |
94 | if prompt_embeds is None:
95 | text_inputs = self.tokenizer(
96 | prompt,
97 | padding="max_length",
98 | max_length=self.tokenizer.model_max_length,
99 | truncation=True,
100 | return_tensors="pt",
101 | )
102 | text_input_ids = text_inputs.input_ids
103 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
104 |
105 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
106 | text_input_ids, untruncated_ids
107 | ):
108 | removed_text = self.tokenizer.batch_decode(
109 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
110 | )
111 | logger.warning(
112 | "The following part of your input was truncated because CLIP can only handle sequences up to"
113 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
114 | )
115 |
116 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
117 | attention_mask = text_inputs.attention_mask.to(device)
118 | else:
119 | attention_mask = None
120 |
121 | prompt_embeds = self.text_encoder(
122 | text_input_ids.to(device),
123 | attention_mask=attention_mask,
124 | )
125 | prompt_embeds = prompt_embeds[0]
126 |
127 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
128 |
129 | bs_embed, seq_len, _ = prompt_embeds.shape
130 | # duplicate text embeddings for each generation per prompt, using mps friendly method
131 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
132 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
133 |
134 | # get unconditional embeddings for classifier free guidance
135 | if do_classifier_free_guidance and negative_prompt_embeds is None:
136 | uncond_tokens: List[str]
137 | if negative_prompt is None:
138 | uncond_tokens = [""] * batch_size
139 | elif type(prompt) is not type(negative_prompt):
140 | raise TypeError(
141 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
142 | f" {type(prompt)}."
143 | )
144 | elif isinstance(negative_prompt, str):
145 | uncond_tokens = [negative_prompt]
146 | elif batch_size != len(negative_prompt):
147 | raise ValueError(
148 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
149 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
150 | " the batch size of `prompt`."
151 | )
152 | else:
153 | uncond_tokens = negative_prompt
154 |
155 | max_length = prompt_embeds.shape[1]
156 | uncond_input = self.tokenizer(
157 | uncond_tokens,
158 | padding="max_length",
159 | max_length=max_length,
160 | truncation=True,
161 | return_tensors="pt",
162 | )
163 |
164 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
165 | attention_mask = uncond_input.attention_mask.to(device)
166 | else:
167 | attention_mask = None
168 |
169 | negative_prompt_embeds = self.text_encoder(
170 | uncond_input.input_ids.to(device),
171 | attention_mask=attention_mask,
172 | )
173 | negative_prompt_embeds = negative_prompt_embeds[0]
174 |
175 | if do_classifier_free_guidance:
176 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
177 | seq_len = negative_prompt_embeds.shape[1]
178 |
179 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
180 |
181 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
182 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
183 |
184 | # For classifier free guidance, we need to do two forward passes.
185 | # Here we concatenate the unconditional and text embeddings into a single batch
186 | # to avoid doing two forward passes
187 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
188 |
189 | return text_inputs, prompt_embeds
190 |
191 | def _compute_max_attention_per_index(self,
192 | attention_maps: torch.Tensor,
193 | indices_to_alter: List[int],
194 | smooth_attentions: bool = False,
195 | sigma: float = 0.5,
196 | kernel_size: int = 3,
197 | normalize_eot: bool = False) -> List[torch.Tensor]:
198 | """ Computes the maximum attention value for each of the tokens we wish to alter. """
199 | last_idx = -1
200 | if normalize_eot:
201 | prompt = self.prompt
202 | if isinstance(self.prompt, list):
203 | prompt = self.prompt[0]
204 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1
205 | attention_for_text = attention_maps[:, :, 1:last_idx]
206 | attention_for_text *= 100
207 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)
208 |
209 | # Shift indices since we removed the first token
210 | indices_to_alter = [index - 1 for index in indices_to_alter]
211 |
212 | # Extract the maximum values
213 | max_indices_list = []
214 | for i in indices_to_alter:
215 | image = attention_for_text[:, :, i]
216 | if smooth_attentions:
217 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda()
218 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
219 | image = smoothing(input).squeeze(0).squeeze(0)
220 | max_indices_list.append(image.max())
221 | return max_indices_list
222 |
223 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore,
224 | indices_to_alter: List[int],
225 | attention_res: int = 16,
226 | smooth_attentions: bool = False,
227 | sigma: float = 0.5,
228 | kernel_size: int = 3,
229 | normalize_eot: bool = False):
230 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. """
231 | attention_maps = aggregate_attention(
232 | attention_store=attention_store,
233 | res=attention_res,
234 | from_where=("up", "down", "mid"),
235 | is_cross=True,
236 | select=0)
237 | max_attention_per_index = self._compute_max_attention_per_index(
238 | attention_maps=attention_maps,
239 | indices_to_alter=indices_to_alter,
240 | smooth_attentions=smooth_attentions,
241 | sigma=sigma,
242 | kernel_size=kernel_size,
243 | normalize_eot=normalize_eot)
244 | return max_attention_per_index
245 |
246 | @staticmethod
247 | def _compute_loss(max_attention_per_index: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor:
248 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """
249 | losses = [max(0, 1. - curr_max) for curr_max in max_attention_per_index]
250 | loss = max(losses)
251 | if return_losses:
252 | return loss, losses
253 | else:
254 | return loss
255 |
256 | @staticmethod
257 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
258 | """ Update the latent according to the computed loss. """
259 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
260 | latents = latents - step_size * grad_cond
261 | return latents
262 |
263 | def _perform_iterative_refinement_step(self,
264 | latents: torch.Tensor,
265 | indices_to_alter: List[int],
266 | loss: torch.Tensor,
267 | threshold: float,
268 | text_embeddings: torch.Tensor,
269 | text_input,
270 | attention_store: AttentionStore,
271 | step_size: float,
272 | t: int,
273 | attention_res: int = 16,
274 | smooth_attentions: bool = True,
275 | sigma: float = 0.5,
276 | kernel_size: int = 3,
277 | max_refinement_steps: int = 20,
278 | normalize_eot: bool = False):
279 | """
280 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
281 | code according to our loss objective until the given threshold is reached for all tokens.
282 | """
283 | iteration = 0
284 | target_loss = max(0, 1. - threshold)
285 | while loss > target_loss:
286 | iteration += 1
287 |
288 | latents = latents.clone().detach().requires_grad_(True)
289 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
290 | self.unet.zero_grad()
291 |
292 | # Get max activation value for each subject token
293 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
294 | attention_store=attention_store,
295 | indices_to_alter=indices_to_alter,
296 | attention_res=attention_res,
297 | smooth_attentions=smooth_attentions,
298 | sigma=sigma,
299 | kernel_size=kernel_size,
300 | normalize_eot=normalize_eot
301 | )
302 |
303 | loss, losses = self._compute_loss(max_attention_per_index, return_losses=True)
304 |
305 | if loss != 0:
306 | latents = self._update_latent(latents, loss, step_size)
307 |
308 | with torch.no_grad():
309 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample
310 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
311 |
312 | try:
313 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses])
314 | except Exception as e:
315 | print(e) # catch edge case :)
316 | low_token = np.argmax(losses)
317 |
318 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]])
319 | print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index[low_token]}')
320 |
321 | if iteration >= max_refinement_steps:
322 | print(f'\t Exceeded max number of iterations ({max_refinement_steps})! '
323 | f'Finished with a max attention of {max_attention_per_index[low_token]}')
324 | break
325 |
326 | # Run the forward pass one more time to compute the updated loss, but do not update the latents here.
327 | # The latents themselves are updated back in the main denoising loop.
328 | latents = latents.clone().detach().requires_grad_(True)
329 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample
330 | self.unet.zero_grad()
331 |
332 | # Get max activation value for each subject token
333 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
334 | attention_store=attention_store,
335 | indices_to_alter=indices_to_alter,
336 | attention_res=attention_res,
337 | smooth_attentions=smooth_attentions,
338 | sigma=sigma,
339 | kernel_size=kernel_size,
340 | normalize_eot=normalize_eot)
341 | loss, losses = self._compute_loss(max_attention_per_index, return_losses=True)
342 | print(f"\t Finished with loss of: {loss}")
343 | return loss, latents, max_attention_per_index
344 |
345 | @torch.no_grad()
346 | def __call__(
347 | self,
348 | prompt: Union[str, List[str]],
349 | attention_store: AttentionStore,
350 | indices_to_alter: List[int],
351 | attention_res: int = 16,
352 | height: Optional[int] = None,
353 | width: Optional[int] = None,
354 | num_inference_steps: int = 50,
355 | guidance_scale: float = 7.5,
356 | negative_prompt: Optional[Union[str, List[str]]] = None,
357 | num_images_per_prompt: Optional[int] = 1,
358 | eta: float = 0.0,
359 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
360 | latents: Optional[torch.FloatTensor] = None,
361 | prompt_embeds: Optional[torch.FloatTensor] = None,
362 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
363 | output_type: Optional[str] = "pil",
364 | return_dict: bool = True,
365 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
366 | callback_steps: Optional[int] = 1,
367 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
368 | max_iter_to_alter: Optional[int] = 25,
369 | run_standard_sd: bool = False,
370 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8},
371 | scale_factor: int = 20,
372 | scale_range: Tuple[float, float] = (1., 0.5),
373 | smooth_attentions: bool = True,
374 | sigma: float = 0.5,
375 | kernel_size: int = 3,
376 | sd_2_1: bool = False,
377 | ):
378 | r"""
379 | Function invoked when calling the pipeline for generation.
380 | Args:
381 | prompt (`str` or `List[str]`, *optional*):
382 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
383 | instead.
384 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
385 | The height in pixels of the generated image.
386 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
387 | The width in pixels of the generated image.
388 | num_inference_steps (`int`, *optional*, defaults to 50):
389 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
390 | expense of slower inference.
391 | guidance_scale (`float`, *optional*, defaults to 7.5):
392 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
393 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
394 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
395 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
396 | usually at the expense of lower image quality.
397 | negative_prompt (`str` or `List[str]`, *optional*):
398 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
399 | `negative_prompt_embeds` instead.
400 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
401 | num_images_per_prompt (`int`, *optional*, defaults to 1):
402 | The number of images to generate per prompt.
403 | eta (`float`, *optional*, defaults to 0.0):
404 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
405 | [`schedulers.DDIMScheduler`], will be ignored for others.
406 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
407 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
408 | to make generation deterministic.
409 | latents (`torch.FloatTensor`, *optional*):
410 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
411 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
412 | tensor will be generated by sampling using the supplied random `generator`.
413 | prompt_embeds (`torch.FloatTensor`, *optional*):
414 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
415 | provided, text embeddings will be generated from `prompt` input argument.
416 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
417 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
418 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
419 | argument.
420 | output_type (`str`, *optional*, defaults to `"pil"`):
421 | The output format of the generated image. Choose between
422 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
423 | return_dict (`bool`, *optional*, defaults to `True`):
424 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
425 | plain tuple.
426 | callback (`Callable`, *optional*):
427 | A function that will be called every `callback_steps` steps during inference. The function will be
428 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
429 | callback_steps (`int`, *optional*, defaults to 1):
430 | The frequency at which the `callback` function will be called. If not specified, the callback will be
431 | called at every step.
432 | cross_attention_kwargs (`dict`, *optional*):
433 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
434 | `self.processor` in
435 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
436 | Examples:
437 | Returns:
438 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
439 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
440 | When returning a tuple, the first element is a list with the generated images, and the second element is a
441 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
442 | (nsfw) content, according to the `safety_checker`.
443 | :type attention_store: object
444 | """
445 | # 0. Default height and width to unet
446 | height = height or self.unet.config.sample_size * self.vae_scale_factor
447 | width = width or self.unet.config.sample_size * self.vae_scale_factor
448 |
449 | # 1. Check inputs. Raise error if not correct
450 | self.check_inputs(
451 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
452 | )
453 |
454 | # 2. Define call parameters
455 | self.prompt = prompt
456 | if prompt is not None and isinstance(prompt, str):
457 | batch_size = 1
458 | elif prompt is not None and isinstance(prompt, list):
459 | batch_size = len(prompt)
460 | else:
461 | batch_size = prompt_embeds.shape[0]
462 |
463 | device = self._execution_device
464 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
465 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
466 | # corresponds to doing no classifier free guidance.
467 | do_classifier_free_guidance = guidance_scale > 1.0
468 |
469 | # 3. Encode input prompt
470 | text_inputs, prompt_embeds = self._encode_prompt(
471 | prompt,
472 | device,
473 | num_images_per_prompt,
474 | do_classifier_free_guidance,
475 | negative_prompt,
476 | prompt_embeds=prompt_embeds,
477 | negative_prompt_embeds=negative_prompt_embeds,
478 | )
479 |
480 | # 4. Prepare timesteps
481 | self.scheduler.set_timesteps(num_inference_steps, device=device)
482 | timesteps = self.scheduler.timesteps
483 |
484 | # 5. Prepare latent variables
485 | num_channels_latents = self.unet.in_channels
486 | latents = self.prepare_latents(
487 | batch_size * num_images_per_prompt,
488 | num_channels_latents,
489 | height,
490 | width,
491 | prompt_embeds.dtype,
492 | device,
493 | generator,
494 | latents,
495 | )
496 |
497 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
498 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
499 |
500 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))
501 |
502 | if max_iter_to_alter is None:
503 | max_iter_to_alter = len(self.scheduler.timesteps) + 1
504 |
505 | # 7. Denoising loop
506 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
507 | with self.progress_bar(total=num_inference_steps) as progress_bar:
508 | for i, t in enumerate(timesteps):
509 |
510 | with torch.enable_grad():
511 |
512 | latents = latents.clone().detach().requires_grad_(True)
513 |
514 | # Forward pass of denoising with text conditioning
515 | noise_pred_text = self.unet(latents, t,
516 | encoder_hidden_states=prompt_embeds[1].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs).sample
517 | self.unet.zero_grad()
518 |
519 | # Get max activation value for each subject token
520 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
521 | attention_store=attention_store,
522 | indices_to_alter=indices_to_alter,
523 | attention_res=attention_res,
524 | smooth_attentions=smooth_attentions,
525 | sigma=sigma,
526 | kernel_size=kernel_size,
527 | normalize_eot=sd_2_1)
528 |
529 | if not run_standard_sd:
530 |
531 | loss = self._compute_loss(max_attention_per_index=max_attention_per_index)
532 |
533 | # If this is an iterative refinement step, verify we have reached the desired threshold for all
534 | if i in thresholds.keys() and loss > 1. - thresholds[i]:
535 | del noise_pred_text
536 | torch.cuda.empty_cache()
537 | loss, latents, max_attention_per_index = self._perform_iterative_refinement_step(
538 | latents=latents,
539 | indices_to_alter=indices_to_alter,
540 | loss=loss,
541 | threshold=thresholds[i],
542 | text_embeddings=prompt_embeds,
543 | text_input=text_inputs,
544 | attention_store=attention_store,
545 | step_size=scale_factor * np.sqrt(scale_range[i]),
546 | t=t,
547 | attention_res=attention_res,
548 | smooth_attentions=smooth_attentions,
549 | sigma=sigma,
550 | kernel_size=kernel_size,
551 | normalize_eot=sd_2_1)
552 |
553 | # Perform gradient update
554 | if i < max_iter_to_alter:
555 | loss = self._compute_loss(max_attention_per_index=max_attention_per_index)
556 | if loss != 0:
557 | latents = self._update_latent(latents=latents, loss=loss,
558 | step_size=scale_factor * np.sqrt(scale_range[i]))
559 | print(f'Iteration {i} | Loss: {loss:0.4f}')
560 |
561 | # expand the latents if we are doing classifier free guidance
562 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
563 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
564 |
565 | # predict the noise residual
566 | noise_pred = self.unet(
567 | latent_model_input,
568 | t,
569 | encoder_hidden_states=prompt_embeds,
570 | cross_attention_kwargs=cross_attention_kwargs,
571 | ).sample
572 |
573 | # perform guidance
574 | if do_classifier_free_guidance:
575 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
576 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
577 |
578 | # compute the previous noisy sample x_t -> x_t-1
579 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
580 |
581 | # call the callback, if provided
582 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
583 | progress_bar.update()
584 | if callback is not None and i % callback_steps == 0:
585 | callback(i, t, latents)
586 |
587 | # 8. Post-processing
588 | image = self.decode_latents(latents)
589 |
590 | # 9. Run safety checker
591 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
592 |
593 | # 10. Convert to PIL
594 | if output_type == "pil":
595 | image = self.numpy_to_pil(image)
596 |
597 | if not return_dict:
598 | return (image, has_nsfw_concept)
599 |
600 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
601 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | from typing import List
3 |
4 | import pyrallis
5 | import torch
6 | from PIL import Image
7 |
8 | from config import RunConfig
9 | from pipeline_attend_and_excite import AttendAndExcitePipeline
10 | from utils import ptp_utils, vis_utils
11 | from utils.ptp_utils import AttentionStore
12 |
13 | import warnings
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 |
17 | def load_model(config: RunConfig):
18 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
19 |
20 | if config.sd_2_1:
21 | stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
22 | else:
23 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
24 | stable = AttendAndExcitePipeline.from_pretrained(stable_diffusion_version).to(device)
25 | return stable
26 |
27 |
28 | def get_indices_to_alter(stable, prompt: str) -> List[int]:
29 | token_idx_to_word = {idx: stable.tokenizer.decode(t)
30 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
31 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1}
32 | pprint.pprint(token_idx_to_word)
33 | token_indices = input("Please enter a comma-separated list of the indices of the tokens you wish to "
34 | "alter (e.g., 2,5): ")
35 | token_indices = [int(i) for i in token_indices.split(",")]
36 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}")
37 | return token_indices
38 |
39 |
40 | def run_on_prompt(prompt: List[str],
41 | model: AttendAndExcitePipeline,
42 | controller: AttentionStore,
43 | token_indices: List[int],
44 | seed: torch.Generator,
45 | config: RunConfig) -> Image.Image:
46 | if controller is not None:
47 | ptp_utils.register_attention_control(model, controller)
48 | outputs = model(prompt=prompt,
49 | attention_store=controller,
50 | indices_to_alter=token_indices,
51 | attention_res=config.attention_res,
52 | guidance_scale=config.guidance_scale,
53 | generator=seed,
54 | num_inference_steps=config.n_inference_steps,
55 | max_iter_to_alter=config.max_iter_to_alter,
56 | run_standard_sd=config.run_standard_sd,
57 | thresholds=config.thresholds,
58 | scale_factor=config.scale_factor,
59 | scale_range=config.scale_range,
60 | smooth_attentions=config.smooth_attentions,
61 | sigma=config.sigma,
62 | kernel_size=config.kernel_size,
63 | sd_2_1=config.sd_2_1)
64 | image = outputs.images[0]
65 | return image
66 |
67 |
68 | @pyrallis.wrap()
69 | def main(config: RunConfig):
70 | stable = load_model(config)
71 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices
72 |
73 | images = []
74 | for seed in config.seeds:
75 | print(f"Seed: {seed}")
76 | g = torch.Generator('cuda').manual_seed(seed)
77 | controller = AttentionStore()
78 | image = run_on_prompt(prompt=config.prompt,
79 | model=stable,
80 | controller=controller,
81 | token_indices=token_indices,
82 | seed=g,
83 | config=config)
84 | prompt_output_path = config.output_path / config.prompt
85 | prompt_output_path.mkdir(exist_ok=True, parents=True)
86 | image.save(prompt_output_path / f'{seed}.png')
87 | images.append(image)
88 |
89 | # save a grid of results across all seeds
90 | joined_image = vis_utils.get_image_grid(images)
91 | joined_image.save(config.output_path / f'{config.prompt}.png')
92 |
93 |
94 | if __name__ == '__main__':
95 | main()
96 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/utils/__init__.py
--------------------------------------------------------------------------------
/utils/gaussian_smoothing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class GaussianSmoothing(nn.Module):
9 | """
10 | Apply gaussian smoothing on a
11 |     1d, 2d or 3d tensor. Filtering is performed separately for each channel
12 | in the input using a depthwise convolution.
13 | Arguments:
14 | channels (int, sequence): Number of channels of the input tensors. Output will
15 | have this number of channels as well.
16 | kernel_size (int, sequence): Size of the gaussian kernel.
17 | sigma (float, sequence): Standard deviation of the gaussian kernel.
18 | dim (int, optional): The number of dimensions of the data.
19 | Default value is 2 (spatial).
20 | """
21 | def __init__(self, channels, kernel_size, sigma, dim=2):
22 | super(GaussianSmoothing, self).__init__()
23 | if isinstance(kernel_size, numbers.Number):
24 | kernel_size = [kernel_size] * dim
25 | if isinstance(sigma, numbers.Number):
26 | sigma = [sigma] * dim
27 |
28 | # The gaussian kernel is the product of the
29 | # gaussian function of each dimension.
30 | kernel = 1
31 | meshgrids = torch.meshgrid(
32 | [
33 | torch.arange(size, dtype=torch.float32)
34 | for size in kernel_size
35 | ]
36 | )
37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
38 | mean = (size - 1) / 2
39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
41 |
42 | # Make sure sum of values in gaussian kernel equals 1.
43 | kernel = kernel / torch.sum(kernel)
44 |
45 | # Reshape to depthwise convolutional weight
46 | kernel = kernel.view(1, 1, *kernel.size())
47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
48 |
49 | self.register_buffer('weight', kernel)
50 | self.groups = channels
51 |
52 | if dim == 1:
53 | self.conv = F.conv1d
54 | elif dim == 2:
55 | self.conv = F.conv2d
56 | elif dim == 3:
57 | self.conv = F.conv3d
58 | else:
59 | raise RuntimeError(
60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
61 | )
62 |
63 | def forward(self, input):
64 | """
65 | Apply gaussian filter to input.
66 | Arguments:
67 | input (torch.Tensor): Input to apply gaussian filter on.
68 | Returns:
69 | filtered (torch.Tensor): Filtered output.
70 | """
71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
72 |
73 |
74 | class AverageSmoothing(nn.Module):
75 | """
76 | Apply average smoothing on a
77 |     1d, 2d or 3d tensor. Filtering is performed separately for each channel
78 | in the input using a depthwise convolution.
79 | Arguments:
80 | channels (int, sequence): Number of channels of the input tensors. Output will
81 | have this number of channels as well.
82 |         kernel_size (int, sequence): Size of the average (box) kernel; its
83 |             uniform weights sum to 1.
84 | dim (int, optional): The number of dimensions of the data.
85 | Default value is 2 (spatial).
86 | """
87 | def __init__(self, channels, kernel_size, dim=2):
88 | super(AverageSmoothing, self).__init__()
89 |
90 |         # Make sure sum of values in average kernel equals 1.
91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size)
92 |
93 | # Reshape to depthwise convolutional weight
94 | kernel = kernel.view(1, 1, *kernel.size())
95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
96 |
97 | self.register_buffer('weight', kernel)
98 | self.groups = channels
99 |
100 | if dim == 1:
101 | self.conv = F.conv1d
102 | elif dim == 2:
103 | self.conv = F.conv2d
104 | elif dim == 3:
105 | self.conv = F.conv3d
106 | else:
107 | raise RuntimeError(
108 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
109 | )
110 |
111 | def forward(self, input):
112 | """
113 | Apply average filter to input.
114 | Arguments:
115 | input (torch.Tensor): Input to apply average filter on.
116 | Returns:
117 | filtered (torch.Tensor): Filtered output.
118 | """
119 | return self.conv(input, weight=self.weight, groups=self.groups)
120 |
--------------------------------------------------------------------------------
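A minimal usage sketch for GaussianSmoothing (not part of the repository): smoothing a single 16x16 attention map while keeping its spatial size. Because the depthwise convolution applies no padding, the input is reflect-padded by (kernel_size - 1) // 2 on each side first; kernel_size=3 and sigma=0.5 are assumed example values.

    import torch
    import torch.nn.functional as F
    from utils.gaussian_smoothing import GaussianSmoothing

    attention_map = torch.rand(16, 16)              # stand-in for one token's cross-attention map
    smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
    padded = F.pad(attention_map[None, None], (1, 1, 1, 1), mode='reflect')
    smoothed = smoothing(padded).squeeze()          # back to a 16x16 map
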
/utils/ptp_utils.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from IPython.display import display
7 | from PIL import Image
8 | from typing import Union, Tuple, List
9 |
10 | from diffusers.models.cross_attention import CrossAttention
11 |
12 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
13 | h, w, c = image.shape
14 | offset = int(h * .2)
15 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
16 | font = cv2.FONT_HERSHEY_SIMPLEX
17 | img[:h] = image
18 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
21 | return img
22 |
23 |
24 | def view_images(images: Union[np.ndarray, List],
25 | num_rows: int = 1,
26 | offset_ratio: float = 0.02,
27 | display_image: bool = True) -> Image.Image:
28 | """ Displays a list of images in a grid. """
29 | if type(images) is list:
30 | num_empty = len(images) % num_rows
31 | elif images.ndim == 4:
32 | num_empty = images.shape[0] % num_rows
33 | else:
34 | images = [images]
35 | num_empty = 0
36 |
37 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
38 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
39 | num_items = len(images)
40 |
41 | h, w, c = images[0].shape
42 | offset = int(h * offset_ratio)
43 | num_cols = num_items // num_rows
44 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
45 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
46 | for i in range(num_rows):
47 | for j in range(num_cols):
48 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
49 | i * num_cols + j]
50 |
51 | pil_img = Image.fromarray(image_)
52 | if display_image:
53 | display(pil_img)
54 | return pil_img
55 |
56 |
57 | class AttendExciteCrossAttnProcessor:
58 |
59 | def __init__(self, attnstore, place_in_unet):
60 | super().__init__()
61 | self.attnstore = attnstore
62 | self.place_in_unet = place_in_unet
63 |
64 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
65 | batch_size, sequence_length, _ = hidden_states.shape
66 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
67 |
68 | query = attn.to_q(hidden_states)
69 |
70 | is_cross = encoder_hidden_states is not None
71 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
72 | key = attn.to_k(encoder_hidden_states)
73 | value = attn.to_v(encoder_hidden_states)
74 |
75 | query = attn.head_to_batch_dim(query)
76 | key = attn.head_to_batch_dim(key)
77 | value = attn.head_to_batch_dim(value)
78 |
79 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
80 |
81 | self.attnstore(attention_probs, is_cross, self.place_in_unet)
82 |
83 | hidden_states = torch.bmm(attention_probs, value)
84 | hidden_states = attn.batch_to_head_dim(hidden_states)
85 |
86 | # linear proj
87 | hidden_states = attn.to_out[0](hidden_states)
88 | # dropout
89 | hidden_states = attn.to_out[1](hidden_states)
90 |
91 | return hidden_states
92 |
93 |
94 | def register_attention_control(model, controller):
95 |
96 | attn_procs = {}
97 | cross_att_count = 0
98 | for name in model.unet.attn_processors.keys():
99 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
100 | if name.startswith("mid_block"):
101 | hidden_size = model.unet.config.block_out_channels[-1]
102 | place_in_unet = "mid"
103 | elif name.startswith("up_blocks"):
104 | block_id = int(name[len("up_blocks.")])
105 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
106 | place_in_unet = "up"
107 | elif name.startswith("down_blocks"):
108 | block_id = int(name[len("down_blocks.")])
109 | hidden_size = model.unet.config.block_out_channels[block_id]
110 | place_in_unet = "down"
111 | else:
112 | continue
113 |
114 | cross_att_count += 1
115 | attn_procs[name] = AttendExciteCrossAttnProcessor(
116 | attnstore=controller, place_in_unet=place_in_unet
117 | )
118 |
119 | model.unet.set_attn_processor(attn_procs)
120 | controller.num_att_layers = cross_att_count
121 |
122 |
123 | class AttentionControl(abc.ABC):
124 |
125 | def step_callback(self, x_t):
126 | return x_t
127 |
128 | def between_steps(self):
129 | return
130 |
131 | @property
132 | def num_uncond_att_layers(self):
133 | return 0
134 |
135 | @abc.abstractmethod
136 | def forward(self, attn, is_cross: bool, place_in_unet: str):
137 | raise NotImplementedError
138 |
139 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
140 | if self.cur_att_layer >= self.num_uncond_att_layers:
141 | self.forward(attn, is_cross, place_in_unet)
142 | self.cur_att_layer += 1
143 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
144 | self.cur_att_layer = 0
145 | self.cur_step += 1
146 | self.between_steps()
147 |
148 | def reset(self):
149 | self.cur_step = 0
150 | self.cur_att_layer = 0
151 |
152 | def __init__(self):
153 | self.cur_step = 0
154 | self.num_att_layers = -1
155 | self.cur_att_layer = 0
156 |
157 |
158 | class EmptyControl(AttentionControl):
159 |
160 | def forward(self, attn, is_cross: bool, place_in_unet: str):
161 | return attn
162 |
163 |
164 | class AttentionStore(AttentionControl):
165 |
166 | @staticmethod
167 | def get_empty_store():
168 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
169 | "down_self": [], "mid_self": [], "up_self": []}
170 |
171 | def forward(self, attn, is_cross: bool, place_in_unet: str):
172 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
173 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
174 | self.step_store[key].append(attn)
175 | return attn
176 |
177 | def between_steps(self):
178 | self.attention_store = self.step_store
179 | if self.save_global_store:
180 | with torch.no_grad():
181 | if len(self.global_store) == 0:
182 | self.global_store = self.step_store
183 | else:
184 | for key in self.global_store:
185 | for i in range(len(self.global_store[key])):
186 | self.global_store[key][i] += self.step_store[key][i].detach()
187 | self.step_store = self.get_empty_store()
188 | self.step_store = self.get_empty_store()
189 |
190 | def get_average_attention(self):
191 | average_attention = self.attention_store
192 | return average_attention
193 |
194 | def get_average_global_attention(self):
195 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
196 | self.attention_store}
197 | return average_attention
198 |
199 | def reset(self):
200 | super(AttentionStore, self).reset()
201 | self.step_store = self.get_empty_store()
202 | self.attention_store = {}
203 | self.global_store = {}
204 |
205 | def __init__(self, save_global_store=False):
206 | '''
207 |         Initialize an empty AttentionStore.
208 |         :param save_global_store: if True, also accumulate the attention maps across all diffusion steps
209 | '''
210 | super(AttentionStore, self).__init__()
211 | self.save_global_store = save_global_store
212 | self.step_store = self.get_empty_store()
213 | self.attention_store = {}
214 | self.global_store = {}
215 | self.curr_step_index = 0
216 |
217 |
218 | def aggregate_attention(attention_store: AttentionStore,
219 | res: int,
220 | from_where: List[str],
221 | is_cross: bool,
222 | select: int) -> torch.Tensor:
223 | """ Aggregates the attention across the different layers and heads at the specified resolution. """
224 | out = []
225 | attention_maps = attention_store.get_average_attention()
226 | num_pixels = res ** 2
227 | for location in from_where:
228 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
229 | if item.shape[1] == num_pixels:
230 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select]
231 | out.append(cross_maps)
232 | out = torch.cat(out, dim=0)
233 | out = out.sum(0) / out.shape[0]
234 | return out
235 |
--------------------------------------------------------------------------------
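A hedged sketch of how the pieces in ptp_utils.py fit together (not part of the repository): register an AttentionStore on the pipeline's UNet so every attention layer reports its probabilities, run a generation (elided here; see run_on_prompt in run.py), then pool the stored 16x16 cross-attention maps. The from_where locations are assumed example values.

    from pipeline_attend_and_excite import AttendAndExcitePipeline
    from utils.ptp_utils import AttentionStore, register_attention_control, aggregate_attention

    stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
    controller = AttentionStore()
    register_attention_control(stable, controller)

    # ... run the pipeline with attention_store=controller, as in run_on_prompt ...

    cross_maps = aggregate_attention(controller,
                                     res=16,
                                     from_where=("up", "down", "mid"),
                                     is_cross=True,
                                     select=0)  # shape: (16, 16, num_text_tokens)
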
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 | from PIL import Image
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from utils import ptp_utils
9 | from utils.ptp_utils import AttentionStore, aggregate_attention
10 |
11 |
12 | def show_cross_attention(prompt: str,
13 | attention_store: AttentionStore,
14 | tokenizer,
15 | indices_to_alter: List[int],
16 | res: int,
17 | from_where: List[str],
18 | select: int = 0,
19 | orig_image=None):
20 | tokens = tokenizer.encode(prompt)
21 | decoder = tokenizer.decode
22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu()
23 | images = []
24 |
25 | # show spatial attention for indices of tokens to strengthen
26 | for i in range(len(tokens)):
27 | image = attention_maps[:, :, i]
28 | if i in indices_to_alter:
29 | image = show_image_relevance(image, orig_image)
30 | image = image.astype(np.uint8)
31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2)))
32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
33 | images.append(image)
34 |
35 | ptp_utils.view_images(np.stack(images, axis=0))
36 |
37 |
38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16):
39 | # create heatmap from mask on image
40 | def show_cam_on_image(img, mask):
41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
42 | heatmap = np.float32(heatmap) / 255
43 | cam = heatmap + np.float32(img)
44 | cam = cam / np.max(cam)
45 | return cam
46 |
47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2))
48 | image = np.array(image)
49 |
50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])
51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu
52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear')
53 | image_relevance = image_relevance.cpu() # send it back to cpu
54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2)
56 | image = (image - image.min()) / (image.max() - image.min())
57 | vis = show_cam_on_image(image, image_relevance)
58 | vis = np.uint8(255 * vis)
59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
60 | return vis
61 |
62 |
63 | def get_image_grid(images: List[Image.Image]) -> Image.Image:
64 | num_images = len(images)
65 | cols = int(math.ceil(math.sqrt(num_images)))
66 | rows = int(math.ceil(num_images / cols))
67 | width, height = images[0].size
68 | grid_image = Image.new('RGB', (cols * width, rows * height))
69 | for i, img in enumerate(images):
70 | x = i % cols
71 | y = i // cols
72 | grid_image.paste(img, (x * width, y * height))
73 | return grid_image
74 |
--------------------------------------------------------------------------------
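A quick sketch of tiling several generations into one grid with get_image_grid (not part of the repository), mirroring what main() in run.py does across seeds; the file paths are hypothetical.

    from PIL import Image
    from utils.vis_utils import get_image_grid

    # Hypothetical output files saved by run.py for three different seeds.
    images = [Image.open(f"outputs/a cat and a dog/{seed}.png") for seed in (0, 21, 42)]
    grid = get_image_grid(images)   # 3 images -> 2x2 grid; the unused cell stays black
    grid.save("outputs/a cat and a dog/grid.png")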