├── LICENSE
├── README.md
├── config.py
├── docs
│   ├── explainability.jpg
│   ├── results.jpg
│   └── teaser.jpg
├── environment
│   ├── environment.yaml
│   └── requirements.txt
├── metrics
│   ├── __init__.py
│   ├── blip_captioning_and_clip_similarity.py
│   ├── compute_clip_similarity.py
│   └── imagenet_utils.py
├── notebooks
│   ├── explain.ipynb
│   └── generate_images.ipynb
├── pipeline_attend_and_excite.py
├── run.py
└── utils
    ├── __init__.py
    ├── gaussian_smoothing.py
    ├── ptp_utils.py
    └── vis_utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AttendAndExcite 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models (SIGGRAPH 2023) 2 | 3 | > Hila Chefer*, Yuval Alaluf*, Yael Vinker, Lior Wolf, Daniel Cohen-Or 4 | > Tel Aviv University 5 | > \* Denotes equal contribution 6 | > 7 | > Recent text-to-image generative models have demonstrated an unparalleled ability to generate diverse and creative imagery guided by a target text prompt. While revolutionary, current state-of-the-art diffusion models may still fail in generating images that fully convey the semantics in the given text prompt. We analyze the publicly available Stable Diffusion model and assess the existence of catastrophic neglect, where the model fails to generate one or more of the subjects from the input prompt. Moreover, we find that in some cases the model also fails to correctly bind attributes (e.g., colors) to their corresponding subjects. To help mitigate these failure cases, we introduce the concept of Generative Semantic Nursing (GSN), where we seek to intervene in the generative process on the fly during inference time to improve the faithfulness of the generated images. Using an attention-based formulation of GSN, dubbed Attend-and-Excite, we guide the model to refine the cross-attention units to attend to all subject tokens in the text prompt and strengthen — or excite — their activations, encouraging the model to generate all subjects described in the text prompt. We compare our approach to alternative approaches and demonstrate that it conveys the desired concepts more faithfully across a range of text prompts.
8 | 9 | 10 | 11 | 12 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/hysts/Attend-and-Excite) 13 | [![Replicate](https://replicate.com/daanelson/attend-and-excite/badge)](https://replicate.com/daanelson/attend-and-excite) 14 | 15 |

16 | 17 |
18 | Given a pre-trained text-to-image diffusion model (e.g., Stable Diffusion), our method, Attend-and-Excite, guides the generative model to modify the cross-attention values during the image synthesis process to generate images that more faithfully depict the input text prompt. Stable Diffusion alone (top row) struggles to generate multiple objects (e.g., a horse and a dog). However, by incorporating Attend-and-Excite (bottom row) to strengthen the subject tokens (marked in blue), we achieve images that are more semantically faithful with respect to the input text prompts. 19 |

20 | 21 | ## Description 22 | Official implementation of our Attend-and-Excite paper. 23 | 24 | ## Setup 25 | 26 | ### Environment 27 | Our code builds on the requirements of the official [Stable Diffusion repository](https://github.com/CompVis/stable-diffusion). To set up their environment, please run: 28 | 29 | ``` 30 | conda env create -f environment/environment.yaml 31 | conda activate ldm 32 | ``` 33 | 34 | On top of these requirements, we add several requirements which can be found in `environment/requirements.txt`. These additional requirements are installed automatically by the above command. 35 | 36 | ### Hugging Face Diffusers Library 37 | Our code also relies on Hugging Face's [diffusers](https://github.com/huggingface/diffusers) library for downloading the Stable Diffusion v1.4 model. 38 | 39 | 40 | ## Usage 41 | 42 |
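Under the hood, `run.py` downloads the Stable Diffusion weights through `diffusers` and wraps them in the repository's `AttendAndExcitePipeline` (defined in `pipeline_attend_and_excite.py`). A minimal loading sketch, assuming a CUDA device and the pinned `diffusers==0.12.1` from `environment/requirements.txt`:

```python
import torch

from pipeline_attend_and_excite import AttendAndExcitePipeline

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Mirrors load_model() in run.py: SD v1.4 by default, or SD 2.1-base when --sd_2_1 is set.
model_id = "CompVis/stable-diffusion-v1-4"  # or "stabilityai/stable-diffusion-2-1-base"
stable = AttendAndExcitePipeline.from_pretrained(model_id).to(device)
```

Generation is then driven by `run_on_prompt` in `run.py`, which registers the attention controller (`AttentionStore`) on the UNet before invoking the pipeline.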

43 | 44 |
45 | Example images generated by Stable Diffusion with Attend-and-Excite. 46 |

47 | 48 | To generate an image, you can simply run the `run.py` script. For example, 49 | ``` 50 | python run.py --prompt "a cat and a dog" --seeds [0] --token_indices [2,5] 51 | ``` 52 | Notes: 53 | 54 | - To apply Attend-and-Excite on Stable Diffusion 2.1, specify: `--sd_2_1 True` 55 | - You may run multiple seeds by passing a list of seeds. For example, `--seeds [0,1,2,3]`. 56 | - If you do not provide a list of which token indices to alter using `--token_indices`, we will split the text according to Stable Diffusion's tokenizer and display the index of each token. You will then be able to input which indices you wish to alter (see the tokenization sketch below). 57 | - If you wish to run the standard Stable Diffusion model without Attend-and-Excite, you can do so by passing `--run_standard_sd True`. 58 | - All parameters are defined in `config.py` and are set to their defaults according to the official paper. 59 | 60 | All generated images will be saved to the path `"{config.output_path}/{prompt}"`. We will also save a grid of all images (in the case of multiple seeds) under `config.output_path`. 61 | 62 | ### Float16 Precision 63 | When loading the Stable Diffusion model, you can use `torch.float16` in order to use less memory and attain faster inference: 64 | ```python 65 | stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to(device) 66 | ``` 67 | Note that this may result in a slight degradation of results in some cases. 68 | 69 | ## Notebooks 70 | We provide Jupyter notebooks to reproduce the results from the paper for image generation and explainability via the 71 | cross-attention maps. 72 | 73 |
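As a companion to the `--token_indices` note above, here is a minimal tokenization sketch (mirroring `get_indices_to_alter` in `run.py`); it assumes only the `transformers` dependency from `environment/requirements.txt` and does not require loading the full pipeline:

```python
from transformers import CLIPTokenizer

# Load just the tokenizer used by Stable Diffusion v1.4 (no need for the full pipeline).
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

prompt = "a cat and a dog"
input_ids = tokenizer(prompt)["input_ids"]

# Same mapping that get_indices_to_alter() in run.py prints: index -> decoded token,
# skipping the start-of-text (index 0) and end-of-text (last index) tokens.
for idx, token_id in enumerate(input_ids):
    if 0 < idx < len(input_ids) - 1:
        print(idx, tokenizer.decode(token_id))
# For this prompt, "cat" and "dog" land at indices 2 and 5, matching --token_indices [2,5].
```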

74 | 75 |
76 | Example cross-attention visualizations. 77 |

78 | 79 | ### Generation 80 | `notebooks/generate_images.ipynb` enables image generation using a free-form text prompt with and without Attend-and-Excite. 81 | 82 | ### Explainability 83 | `notebooks/explain.ipynb` produces a comparison of the cross-attention maps before and after applying Attend-and-Excite 84 | as seen in the illustration above. 85 | This notebook can be used to provide an explanation for the generations produced by Attend-and-Excite. 86 | 87 | ## Metrics 88 | In `metrics/` we provide code needed to reproduce the quantitative experiments presented in the paper: 89 | 1. In `compute_clip_similarity.py`, we provide the code needed for computing the image-based CLIP similarities. Here, we compute the CLIP-space similarities between the generated images and the guiding text prompt. 90 | 2. In `blip_captioning_and_clip_similarity.py`, we provide the code needed for computing the text-based CLIP similarities. Here, we generate captions for each generated image using BLIP and compute the CLIP-space similarities between the generated captions and the guiding text prompt. 91 | - Note: to run this script you need to install the `lavis` library. This can be done using `pip install lavis`. 92 | 93 | To run the scripts, you simply need to pass the output directory containing the generated images. The direcory structure should be as follows: 94 | ``` 95 | outputs/ 96 | |-- prompt_1/ 97 | | |-- 0.png 98 | | |-- 1.png 99 | | |-- ... 100 | | |-- 64.png 101 | |-- prompt_2/ 102 | | |-- 0.png 103 | | |-- 1.png 104 | | |-- ... 105 | | |-- 64.png 106 | ... 107 | ``` 108 | The scripts will iterate through all the prompt outputs provided in the root output directory and aggregate results across all images. 109 | 110 | The metrics will be saved to a `json` file under the path specified by `--metrics_save_path`. 111 | 112 | ### Evaluation Prompts 113 | The prompts used in our quantitative evaluations can be found [here](https://github.com/AttendAndExcite/Attend-and-Excite/files/11336216/a.e_prompts.txt). 114 | 115 | ## Acknowledgements 116 | This code is builds on the code from the [diffusers](https://github.com/huggingface/diffusers) library as well as the [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) codebase. 
117 | 118 | ## Citation 119 | If you use this code for your research, please cite the following work: 120 | ``` 121 | @misc{chefer2023attendandexcite, 122 | title={Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models}, 123 | author={Hila Chefer and Yuval Alaluf and Yael Vinker and Lior Wolf and Daniel Cohen-Or}, 124 | year={2023}, 125 | eprint={2301.13826}, 126 | archivePrefix={arXiv}, 127 | primaryClass={cs.CV} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | 6 | @dataclass 7 | class RunConfig: 8 | # Guiding text prompt 9 | prompt: str 10 | # Whether to use Stable Diffusion v2.1 11 | sd_2_1: bool = False 12 | # Which token indices to alter with attend-and-excite 13 | token_indices: List[int] = None 14 | # Which random seeds to use when generating 15 | seeds: List[int] = field(default_factory=lambda: [42]) 16 | # Path to save all outputs to 17 | output_path: Path = Path('./outputs') 18 | # Number of denoising steps 19 | n_inference_steps: int = 50 20 | # Text guidance scale 21 | guidance_scale: float = 7.5 22 | # Number of denoising steps to apply attend-and-excite 23 | max_iter_to_alter: int = 25 24 | # Resolution of UNet to compute attention maps over 25 | attention_res: int = 16 26 | # Whether to run standard SD or attend-and-excite 27 | run_standard_sd: bool = False 28 | # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in 29 | thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8}) 30 | # Scale factor for updating the denoised latent z_t 31 | scale_factor: int = 20 32 | # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep 33 | scale_range: tuple = field(default_factory=lambda: (1.0, 0.5)) 34 | # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token 35 | smooth_attentions: bool = True 36 | # Standard deviation for the Gaussian smoothing 37 | sigma: float = 0.5 38 | # Kernel size for the Gaussian smoothing 39 | kernel_size: int = 3 40 | # Whether to save cross attention maps for the final results 41 | save_cross_attention_maps: bool = False 42 | 43 | def __post_init__(self): 44 | self.output_path.mkdir(exist_ok=True, parents=True) 45 | -------------------------------------------------------------------------------- /docs/explainability.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/explainability.jpg -------------------------------------------------------------------------------- /docs/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/results.jpg -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/docs/teaser.jpg 
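The `RunConfig` dataclass above is what `pyrallis` populates from the command line in `run.py`. For reference, here is a minimal sketch of constructing it directly in Python (a hypothetical programmatic use; field names and defaults are taken from `config.py`, and the repository root is assumed to be importable):

```python
from pathlib import Path

from config import RunConfig  # assumes the repository root is on PYTHONPATH

config = RunConfig(
    prompt="a cat and a dog",
    token_indices=[2, 5],      # positions of "cat" and "dog" under the CLIP tokenizer
    seeds=[0, 1, 2],
    output_path=Path("./outputs"),
)

# __post_init__ has already created config.output_path at this point.
print(config.thresholds)                        # default iterative-refinement thresholds: {0: 0.05, 10: 0.5, 20: 0.8}
print(config.scale_factor, config.scale_range)  # 20 (1.0, 0.5)
```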
-------------------------------------------------------------------------------- /environment/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - torchmetrics==0.6.0 27 | - kornia==0.6 28 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 29 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 30 | - -r requirements.txt 31 | -------------------------------------------------------------------------------- /environment/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | opencv-python 3 | ipywidgets 4 | matplotlib 5 | pyrallis 6 | torch==1.12.0 7 | diffusers==0.12.1 8 | transformers==4.26.0 9 | jupyter 10 | accelerate 11 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/blip_captioning_and_clip_similarity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | import clip 7 | import numpy as np 8 | import pyrallis 9 | import torch 10 | from PIL import Image 11 | from lavis.models import load_model_and_preprocess 12 | from tqdm import tqdm 13 | 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | from metrics.imagenet_utils import get_embedding_for_prompt, imagenet_templates 18 | 19 | 20 | @dataclass 21 | class EvalConfig: 22 | output_path: Path = Path("./outputs/") 23 | metrics_save_path: Path = Path("./metrics/") 24 | 25 | def __post_init__(self): 26 | self.metrics_save_path.mkdir(parents=True, exist_ok=True) 27 | 28 | 29 | @pyrallis.wrap() 30 | def run(config: EvalConfig): 31 | print("Loading CLIP model...") 32 | device = torch.device("cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu") 33 | model, preprocess = clip.load("ViT-B/16", device) 34 | model.eval() 35 | print("Done.") 36 | 37 | print("Loading BLIP model...") 38 | blip_model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", 39 | is_eval=True, device=device) 40 | print("Done.") 41 | 42 | prompts = [p.name for p in config.output_path.glob("*") if p.is_dir()] 43 | print(f"Running on {len(prompts)} prompts...") 44 | 45 | results_per_prompt = {} 46 | for prompt in tqdm(prompts): 47 | print(f'Running on: "{prompt}"') 48 | 49 | # get all images for the given prompt 50 | image_paths = [p for p in (config.output_path / prompt).rglob('*') if p.suffix in ['.png', '.jpg']] 51 | images = [Image.open(p) for p in image_paths] 52 | image_names = [p.name for p in image_paths] 53 | 54 | with 
torch.no_grad(): 55 | # extract prompt embeddings 56 | prompt_features = get_embedding_for_prompt(model, prompt, templates=imagenet_templates) 57 | 58 | # extract blip captions and embeddings 59 | blip_input_images = [vis_processors["eval"](image).unsqueeze(0).to(device) for image in images] 60 | blip_captions = [blip_model.generate({"image": image})[0] for image in blip_input_images] 61 | texts = [clip.tokenize([text]).cuda() for text in blip_captions] 62 | caption_embeddings = [model.encode_text(t) for t in texts] 63 | caption_embeddings = [embedding / embedding.norm(dim=-1, keepdim=True) for embedding in caption_embeddings] 64 | 65 | text_similarities = [(caption_embedding.float() @ prompt_features.T).item() 66 | for caption_embedding in caption_embeddings] 67 | 68 | results_per_prompt[prompt] = { 69 | 'text_similarities': text_similarities, 70 | 'captions': blip_captions, 71 | 'image_names': image_names, 72 | } 73 | 74 | # aggregate results 75 | total_average, total_std = aggregate_text_similarities(results_per_prompt) 76 | aggregated_results = { 77 | 'average_similarity': total_average, 78 | 'std_similarity': total_std, 79 | } 80 | 81 | with open(config.metrics_save_path / "blip_raw_metrics.json", 'w') as f: 82 | json.dump(results_per_prompt, f, sort_keys=True, indent=4) 83 | with open(config.metrics_save_path / "blip_aggregated_metrics.json", 'w') as f: 84 | json.dump(aggregated_results, f, sort_keys=True, indent=4) 85 | 86 | 87 | def aggregate_text_similarities(result_dict): 88 | all_averages = [result_dict[prompt]['text_similarities'] for prompt in result_dict] 89 | all_averages = np.array(all_averages).flatten() 90 | total_average = np.average(all_averages) 91 | total_std = np.std(all_averages) 92 | return total_average, total_std 93 | 94 | 95 | if __name__ == '__main__': 96 | run() 97 | -------------------------------------------------------------------------------- /metrics/compute_clip_similarity.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | import clip 7 | import numpy as np 8 | import pyrallis 9 | import torch 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from metrics.imagenet_utils import get_embedding_for_prompt, imagenet_templates 17 | 18 | 19 | @dataclass 20 | class EvalConfig: 21 | output_path: Path = Path("./outputs/") 22 | metrics_save_path: Path = Path("./metrics/") 23 | 24 | def __post_init__(self): 25 | self.metrics_save_path.mkdir(parents=True, exist_ok=True) 26 | 27 | 28 | @pyrallis.wrap() 29 | def run(config: EvalConfig): 30 | print("Loading CLIP model...") 31 | device = torch.device("cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu") 32 | model, preprocess = clip.load("ViT-B/16", device) 33 | model.eval() 34 | print("Done.") 35 | 36 | prompts = [p.name for p in config.output_path.glob("*") if p.is_dir()] 37 | print(f"Running on {len(prompts)} prompts...") 38 | 39 | results_per_prompt = {} 40 | for prompt in tqdm(prompts): 41 | 42 | print(f'Running on: "{prompt}"') 43 | 44 | # get all images for the given prompt 45 | image_paths = [p for p in (config.output_path / prompt).rglob('*') if p.suffix in ['.png', '.jpg']] 46 | images = [Image.open(p) for p in image_paths] 47 | image_names = [p.name for p in image_paths] 48 | queries = [preprocess(image).unsqueeze(0).to(device) for image in images] 49 | 50 | with 
torch.no_grad(): 51 | 52 | # split prompt into first and second halves 53 | if ' and ' in prompt: 54 | prompt_parts = prompt.split(' and ') 55 | elif ' with ' in prompt: 56 | prompt_parts = prompt.split(' with ') 57 | else: 58 | print(f"Unable to split prompt: {prompt}. " 59 | f"Looking for 'and' or 'with' for splitting! Skipping!") 60 | continue 61 | 62 | # extract texture features 63 | full_text_features = get_embedding_for_prompt(model, prompt, templates=imagenet_templates) 64 | first_half_features = get_embedding_for_prompt(model, prompt_parts[0], templates=imagenet_templates) 65 | second_half_features = get_embedding_for_prompt(model, prompt_parts[1], templates=imagenet_templates) 66 | 67 | # extract image features 68 | images_features = [model.encode_image(image) for image in queries] 69 | images_features = [feats / feats.norm(dim=-1, keepdim=True) for feats in images_features] 70 | 71 | # compute similarities 72 | full_text_similarities = [(feat.float() @ full_text_features.T).item() for feat in images_features] 73 | first_half_similarities = [(feat.float() @ first_half_features.T).item() for feat in images_features] 74 | second_half_similarities = [(feat.float() @ second_half_features.T).item() for feat in images_features] 75 | 76 | results_per_prompt[prompt] = { 77 | 'full_text': full_text_similarities, 78 | 'first_half': first_half_similarities, 79 | 'second_half': second_half_similarities, 80 | 'image_names': image_names, 81 | } 82 | 83 | # aggregate results 84 | aggregated_results = { 85 | 'full_text_aggregation': aggregate_by_full_text(results_per_prompt), 86 | 'min_first_second_aggregation': aggregate_by_min_half(results_per_prompt), 87 | } 88 | 89 | with open(config.metrics_save_path / "clip_raw_metrics.json", 'w') as f: 90 | json.dump(results_per_prompt, f, sort_keys=True, indent=4) 91 | with open(config.metrics_save_path / "clip_aggregated_metrics.json", 'w') as f: 92 | json.dump(aggregated_results, f, sort_keys=True, indent=4) 93 | 94 | 95 | def aggregate_by_min_half(d): 96 | """ Aggregate results for the minimum similarity score for each prompt. """ 97 | min_per_half_res = [[min(a, b) for a, b in zip(d[prompt]["first_half"], d[prompt]["second_half"])] for prompt in d] 98 | min_per_half_res = np.array(min_per_half_res).flatten() 99 | return np.average(min_per_half_res) 100 | 101 | 102 | def aggregate_by_full_text(d): 103 | """ Aggregate results for the full text similarity for each prompt. 
""" 104 | full_text_res = [v['full_text'] for v in d.values()] 105 | full_text_res = np.array(full_text_res).flatten() 106 | return np.average(full_text_res) 107 | 108 | 109 | if __name__ == '__main__': 110 | run() 111 | -------------------------------------------------------------------------------- /metrics/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | import clip 2 | 3 | imagenet_templates = [ 4 | 'a bad photo of a {}.', 5 | 'a photo of many {}.', 6 | 'a sculpture of a {}.', 7 | 'a photo of the hard to see {}.', 8 | 'a low resolution photo of the {}.', 9 | 'a rendering of a {}.', 10 | 'graffiti of a {}.', 11 | 'a bad photo of the {}.', 12 | 'a cropped photo of the {}.', 13 | 'a tattoo of a {}.', 14 | 'the embroidered {}.', 15 | 'a photo of a hard to see {}.', 16 | 'a bright photo of a {}.', 17 | 'a photo of a clean {}.', 18 | 'a photo of a dirty {}.', 19 | 'a dark photo of the {}.', 20 | 'a drawing of a {}.', 21 | 'a photo of my {}.', 22 | 'the plastic {}.', 23 | 'a photo of the cool {}.', 24 | 'a close-up photo of a {}.', 25 | 'a black and white photo of the {}.', 26 | 'a painting of the {}.', 27 | 'a painting of a {}.', 28 | 'a pixelated photo of the {}.', 29 | 'a sculpture of the {}.', 30 | 'a bright photo of the {}.', 31 | 'a cropped photo of a {}.', 32 | 'a plastic {}.', 33 | 'a photo of the dirty {}.', 34 | 'a jpeg corrupted photo of a {}.', 35 | 'a blurry photo of the {}.', 36 | 'a photo of the {}.', 37 | 'a good photo of the {}.', 38 | 'a rendering of the {}.', 39 | 'a {} in a video game.', 40 | 'a photo of one {}.', 41 | 'a doodle of a {}.', 42 | 'a close-up photo of the {}.', 43 | 'a photo of a {}.', 44 | 'the origami {}.', 45 | 'the {} in a video game.', 46 | 'a sketch of a {}.', 47 | 'a doodle of the {}.', 48 | 'a origami {}.', 49 | 'a low resolution photo of a {}.', 50 | 'the toy {}.', 51 | 'a rendition of the {}.', 52 | 'a photo of the clean {}.', 53 | 'a photo of a large {}.', 54 | 'a rendition of a {}.', 55 | 'a photo of a nice {}.', 56 | 'a photo of a weird {}.', 57 | 'a blurry photo of a {}.', 58 | 'a cartoon {}.', 59 | 'art of a {}.', 60 | 'a sketch of the {}.', 61 | 'a embroidered {}.', 62 | 'a pixelated photo of a {}.', 63 | 'itap of the {}.', 64 | 'a jpeg corrupted photo of the {}.', 65 | 'a good photo of a {}.', 66 | 'a plushie {}.', 67 | 'a photo of the nice {}.', 68 | 'a photo of the small {}.', 69 | 'a photo of the weird {}.', 70 | 'the cartoon {}.', 71 | 'art of the {}.', 72 | 'a drawing of the {}.', 73 | 'a photo of the large {}.', 74 | 'a black and white photo of a {}.', 75 | 'the plushie {}.', 76 | 'a dark photo of a {}.', 77 | 'itap of a {}.', 78 | 'graffiti of the {}.', 79 | 'a toy {}.', 80 | 'itap of my {}.', 81 | 'a photo of a cool {}.', 82 | 'a photo of a small {}.', 83 | 'a tattoo of the {}.', 84 | ] 85 | 86 | 87 | def get_embedding_for_prompt(model, prompt, templates): 88 | texts = [template.format(prompt) for template in templates] # format with class 89 | texts = [t.replace('a a', 'a') for t in texts] # remove double a's 90 | texts = [t.replace('the a', 'a') for t in texts] # remove double a's 91 | texts = clip.tokenize(texts).cuda() # tokenize 92 | class_embeddings = model.encode_text(texts) # embed with text encoder 93 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 94 | class_embedding = class_embeddings.mean(dim=0) 95 | class_embedding /= class_embedding.norm() 96 | return class_embedding.float() 97 | 
-------------------------------------------------------------------------------- /pipeline_attend_and_excite.py: -------------------------------------------------------------------------------- 1 | 2 | import inspect 3 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from packaging import version 10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 11 | 12 | from diffusers.configuration_utils import FrozenDict 13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 14 | from diffusers.schedulers import KarrasDiffusionSchedulers 15 | from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring 16 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 17 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 18 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 19 | 20 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 21 | 22 | from utils.gaussian_smoothing import GaussianSmoothing 23 | from utils.ptp_utils import AttentionStore, aggregate_attention 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | class AttendAndExcitePipeline(StableDiffusionPipeline): 28 | r""" 29 | Pipeline for text-to-image generation using Stable Diffusion. 30 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 31 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 32 | Args: 33 | vae ([`AutoencoderKL`]): 34 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 35 | text_encoder ([`CLIPTextModel`]): 36 | Frozen text-encoder. Stable Diffusion uses the text portion of 37 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 38 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 39 | tokenizer (`CLIPTokenizer`): 40 | Tokenizer of class 41 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 42 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 43 | scheduler ([`SchedulerMixin`]): 44 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 45 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 46 | safety_checker ([`StableDiffusionSafetyChecker`]): 47 | Classification module that estimates whether generated images could be considered offensive or harmful. 48 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 49 | feature_extractor ([`CLIPFeatureExtractor`]): 50 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 51 | """ 52 | _optional_components = ["safety_checker", "feature_extractor"] 53 | 54 | def _encode_prompt( 55 | self, 56 | prompt, 57 | device, 58 | num_images_per_prompt, 59 | do_classifier_free_guidance, 60 | negative_prompt=None, 61 | prompt_embeds: Optional[torch.FloatTensor] = None, 62 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 63 | ): 64 | r""" 65 | Encodes the prompt into text encoder hidden states. 
66 | Args: 67 | prompt (`str` or `List[str]`, *optional*): 68 | prompt to be encoded 69 | device: (`torch.device`): 70 | torch device 71 | num_images_per_prompt (`int`): 72 | number of images that should be generated per prompt 73 | do_classifier_free_guidance (`bool`): 74 | whether to use classifier free guidance or not 75 | negative_ prompt (`str` or `List[str]`, *optional*): 76 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 77 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 78 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 79 | prompt_embeds (`torch.FloatTensor`, *optional*): 80 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 81 | provided, text embeddings will be generated from `prompt` input argument. 82 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 83 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 84 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 85 | argument. 86 | """ 87 | if prompt is not None and isinstance(prompt, str): 88 | batch_size = 1 89 | elif prompt is not None and isinstance(prompt, list): 90 | batch_size = len(prompt) 91 | else: 92 | batch_size = prompt_embeds.shape[0] 93 | 94 | if prompt_embeds is None: 95 | text_inputs = self.tokenizer( 96 | prompt, 97 | padding="max_length", 98 | max_length=self.tokenizer.model_max_length, 99 | truncation=True, 100 | return_tensors="pt", 101 | ) 102 | text_input_ids = text_inputs.input_ids 103 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 104 | 105 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 106 | text_input_ids, untruncated_ids 107 | ): 108 | removed_text = self.tokenizer.batch_decode( 109 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 110 | ) 111 | logger.warning( 112 | "The following part of your input was truncated because CLIP can only handle sequences up to" 113 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 114 | ) 115 | 116 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 117 | attention_mask = text_inputs.attention_mask.to(device) 118 | else: 119 | attention_mask = None 120 | 121 | prompt_embeds = self.text_encoder( 122 | text_input_ids.to(device), 123 | attention_mask=attention_mask, 124 | ) 125 | prompt_embeds = prompt_embeds[0] 126 | 127 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 128 | 129 | bs_embed, seq_len, _ = prompt_embeds.shape 130 | # duplicate text embeddings for each generation per prompt, using mps friendly method 131 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 132 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 133 | 134 | # get unconditional embeddings for classifier free guidance 135 | if do_classifier_free_guidance and negative_prompt_embeds is None: 136 | uncond_tokens: List[str] 137 | if negative_prompt is None: 138 | uncond_tokens = [""] * batch_size 139 | elif type(prompt) is not type(negative_prompt): 140 | raise TypeError( 141 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 142 | f" {type(prompt)}." 
143 | ) 144 | elif isinstance(negative_prompt, str): 145 | uncond_tokens = [negative_prompt] 146 | elif batch_size != len(negative_prompt): 147 | raise ValueError( 148 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 149 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 150 | " the batch size of `prompt`." 151 | ) 152 | else: 153 | uncond_tokens = negative_prompt 154 | 155 | max_length = prompt_embeds.shape[1] 156 | uncond_input = self.tokenizer( 157 | uncond_tokens, 158 | padding="max_length", 159 | max_length=max_length, 160 | truncation=True, 161 | return_tensors="pt", 162 | ) 163 | 164 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 165 | attention_mask = uncond_input.attention_mask.to(device) 166 | else: 167 | attention_mask = None 168 | 169 | negative_prompt_embeds = self.text_encoder( 170 | uncond_input.input_ids.to(device), 171 | attention_mask=attention_mask, 172 | ) 173 | negative_prompt_embeds = negative_prompt_embeds[0] 174 | 175 | if do_classifier_free_guidance: 176 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 177 | seq_len = negative_prompt_embeds.shape[1] 178 | 179 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 180 | 181 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 182 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 183 | 184 | # For classifier free guidance, we need to do two forward passes. 185 | # Here we concatenate the unconditional and text embeddings into a single batch 186 | # to avoid doing two forward passes 187 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 188 | 189 | return text_inputs, prompt_embeds 190 | 191 | def _compute_max_attention_per_index(self, 192 | attention_maps: torch.Tensor, 193 | indices_to_alter: List[int], 194 | smooth_attentions: bool = False, 195 | sigma: float = 0.5, 196 | kernel_size: int = 3, 197 | normalize_eot: bool = False) -> List[torch.Tensor]: 198 | """ Computes the maximum attention value for each of the tokens we wish to alter. 
""" 199 | last_idx = -1 200 | if normalize_eot: 201 | prompt = self.prompt 202 | if isinstance(self.prompt, list): 203 | prompt = self.prompt[0] 204 | last_idx = len(self.tokenizer(prompt)['input_ids']) - 1 205 | attention_for_text = attention_maps[:, :, 1:last_idx] 206 | attention_for_text *= 100 207 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) 208 | 209 | # Shift indices since we removed the first token 210 | indices_to_alter = [index - 1 for index in indices_to_alter] 211 | 212 | # Extract the maximum values 213 | max_indices_list = [] 214 | for i in indices_to_alter: 215 | image = attention_for_text[:, :, i] 216 | if smooth_attentions: 217 | smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() 218 | input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect') 219 | image = smoothing(input).squeeze(0).squeeze(0) 220 | max_indices_list.append(image.max()) 221 | return max_indices_list 222 | 223 | def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionStore, 224 | indices_to_alter: List[int], 225 | attention_res: int = 16, 226 | smooth_attentions: bool = False, 227 | sigma: float = 0.5, 228 | kernel_size: int = 3, 229 | normalize_eot: bool = False): 230 | """ Aggregates the attention for each token and computes the max activation value for each token to alter. """ 231 | attention_maps = aggregate_attention( 232 | attention_store=attention_store, 233 | res=attention_res, 234 | from_where=("up", "down", "mid"), 235 | is_cross=True, 236 | select=0) 237 | max_attention_per_index = self._compute_max_attention_per_index( 238 | attention_maps=attention_maps, 239 | indices_to_alter=indices_to_alter, 240 | smooth_attentions=smooth_attentions, 241 | sigma=sigma, 242 | kernel_size=kernel_size, 243 | normalize_eot=normalize_eot) 244 | return max_attention_per_index 245 | 246 | @staticmethod 247 | def _compute_loss(max_attention_per_index: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 248 | """ Computes the attend-and-excite loss using the maximum attention value for each token. """ 249 | losses = [max(0, 1. - curr_max) for curr_max in max_attention_per_index] 250 | loss = max(losses) 251 | if return_losses: 252 | return loss, losses 253 | else: 254 | return loss 255 | 256 | @staticmethod 257 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: 258 | """ Update the latent according to the computed loss. """ 259 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] 260 | latents = latents - step_size * grad_cond 261 | return latents 262 | 263 | def _perform_iterative_refinement_step(self, 264 | latents: torch.Tensor, 265 | indices_to_alter: List[int], 266 | loss: torch.Tensor, 267 | threshold: float, 268 | text_embeddings: torch.Tensor, 269 | text_input, 270 | attention_store: AttentionStore, 271 | step_size: float, 272 | t: int, 273 | attention_res: int = 16, 274 | smooth_attentions: bool = True, 275 | sigma: float = 0.5, 276 | kernel_size: int = 3, 277 | max_refinement_steps: int = 20, 278 | normalize_eot: bool = False): 279 | """ 280 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent 281 | code according to our loss objective until the given threshold is reached for all tokens. 282 | """ 283 | iteration = 0 284 | target_loss = max(0, 1. 
- threshold) 285 | while loss > target_loss: 286 | iteration += 1 287 | 288 | latents = latents.clone().detach().requires_grad_(True) 289 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 290 | self.unet.zero_grad() 291 | 292 | # Get max activation value for each subject token 293 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token( 294 | attention_store=attention_store, 295 | indices_to_alter=indices_to_alter, 296 | attention_res=attention_res, 297 | smooth_attentions=smooth_attentions, 298 | sigma=sigma, 299 | kernel_size=kernel_size, 300 | normalize_eot=normalize_eot 301 | ) 302 | 303 | loss, losses = self._compute_loss(max_attention_per_index, return_losses=True) 304 | 305 | if loss != 0: 306 | latents = self._update_latent(latents, loss, step_size) 307 | 308 | with torch.no_grad(): 309 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample 310 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 311 | 312 | try: 313 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses]) 314 | except Exception as e: 315 | print(e) # catch edge case :) 316 | low_token = np.argmax(losses) 317 | 318 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) 319 | print(f'\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index[low_token]}') 320 | 321 | if iteration >= max_refinement_steps: 322 | print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' 323 | f'Finished with a max attention of {max_attention_per_index[low_token]}') 324 | break 325 | 326 | # Run one more time but don't compute gradients and update the latents. 
327 | # We just need to compute the new loss - the grad update will occur below 328 | latents = latents.clone().detach().requires_grad_(True) 329 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 330 | self.unet.zero_grad() 331 | 332 | # Get max activation value for each subject token 333 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token( 334 | attention_store=attention_store, 335 | indices_to_alter=indices_to_alter, 336 | attention_res=attention_res, 337 | smooth_attentions=smooth_attentions, 338 | sigma=sigma, 339 | kernel_size=kernel_size, 340 | normalize_eot=normalize_eot) 341 | loss, losses = self._compute_loss(max_attention_per_index, return_losses=True) 342 | print(f"\t Finished with loss of: {loss}") 343 | return loss, latents, max_attention_per_index 344 | 345 | @torch.no_grad() 346 | def __call__( 347 | self, 348 | prompt: Union[str, List[str]], 349 | attention_store: AttentionStore, 350 | indices_to_alter: List[int], 351 | attention_res: int = 16, 352 | height: Optional[int] = None, 353 | width: Optional[int] = None, 354 | num_inference_steps: int = 50, 355 | guidance_scale: float = 7.5, 356 | negative_prompt: Optional[Union[str, List[str]]] = None, 357 | num_images_per_prompt: Optional[int] = 1, 358 | eta: float = 0.0, 359 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 360 | latents: Optional[torch.FloatTensor] = None, 361 | prompt_embeds: Optional[torch.FloatTensor] = None, 362 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 363 | output_type: Optional[str] = "pil", 364 | return_dict: bool = True, 365 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 366 | callback_steps: Optional[int] = 1, 367 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 368 | max_iter_to_alter: Optional[int] = 25, 369 | run_standard_sd: bool = False, 370 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8}, 371 | scale_factor: int = 20, 372 | scale_range: Tuple[float, float] = (1., 0.5), 373 | smooth_attentions: bool = True, 374 | sigma: float = 0.5, 375 | kernel_size: int = 3, 376 | sd_2_1: bool = False, 377 | ): 378 | r""" 379 | Function invoked when calling the pipeline for generation. 380 | Args: 381 | prompt (`str` or `List[str]`, *optional*): 382 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 383 | instead. 384 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 385 | The height in pixels of the generated image. 386 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 387 | The width in pixels of the generated image. 388 | num_inference_steps (`int`, *optional*, defaults to 50): 389 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 390 | expense of slower inference. 391 | guidance_scale (`float`, *optional*, defaults to 7.5): 392 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 393 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 394 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 395 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 396 | usually at the expense of lower image quality. 
397 | negative_prompt (`str` or `List[str]`, *optional*): 398 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 399 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 400 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 401 | num_images_per_prompt (`int`, *optional*, defaults to 1): 402 | The number of images to generate per prompt. 403 | eta (`float`, *optional*, defaults to 0.0): 404 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 405 | [`schedulers.DDIMScheduler`], will be ignored for others. 406 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 407 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 408 | to make generation deterministic. 409 | latents (`torch.FloatTensor`, *optional*): 410 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 411 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 412 | tensor will ge generated by sampling using the supplied random `generator`. 413 | prompt_embeds (`torch.FloatTensor`, *optional*): 414 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 415 | provided, text embeddings will be generated from `prompt` input argument. 416 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 417 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 418 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 419 | argument. 420 | output_type (`str`, *optional*, defaults to `"pil"`): 421 | The output format of the generate image. Choose between 422 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 423 | return_dict (`bool`, *optional*, defaults to `True`): 424 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 425 | plain tuple. 426 | callback (`Callable`, *optional*): 427 | A function that will be called every `callback_steps` steps during inference. The function will be 428 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 429 | callback_steps (`int`, *optional*, defaults to 1): 430 | The frequency at which the `callback` function will be called. If not specified, the callback will be 431 | called at every step. 432 | cross_attention_kwargs (`dict`, *optional*): 433 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 434 | `self.processor` in 435 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 436 | Examples: 437 | Returns: 438 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 439 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 440 | When returning a tuple, the first element is a list with the generated images, and the second element is a 441 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 442 | (nsfw) content, according to the `safety_checker`. 443 | :type attention_store: object 444 | """ 445 | # 0. 
Default height and width to unet 446 | height = height or self.unet.config.sample_size * self.vae_scale_factor 447 | width = width or self.unet.config.sample_size * self.vae_scale_factor 448 | 449 | # 1. Check inputs. Raise error if not correct 450 | self.check_inputs( 451 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 452 | ) 453 | 454 | # 2. Define call parameters 455 | self.prompt = prompt 456 | if prompt is not None and isinstance(prompt, str): 457 | batch_size = 1 458 | elif prompt is not None and isinstance(prompt, list): 459 | batch_size = len(prompt) 460 | else: 461 | batch_size = prompt_embeds.shape[0] 462 | 463 | device = self._execution_device 464 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 465 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 466 | # corresponds to doing no classifier free guidance. 467 | do_classifier_free_guidance = guidance_scale > 1.0 468 | 469 | # 3. Encode input prompt 470 | text_inputs, prompt_embeds = self._encode_prompt( 471 | prompt, 472 | device, 473 | num_images_per_prompt, 474 | do_classifier_free_guidance, 475 | negative_prompt, 476 | prompt_embeds=prompt_embeds, 477 | negative_prompt_embeds=negative_prompt_embeds, 478 | ) 479 | 480 | # 4. Prepare timesteps 481 | self.scheduler.set_timesteps(num_inference_steps, device=device) 482 | timesteps = self.scheduler.timesteps 483 | 484 | # 5. Prepare latent variables 485 | num_channels_latents = self.unet.in_channels 486 | latents = self.prepare_latents( 487 | batch_size * num_images_per_prompt, 488 | num_channels_latents, 489 | height, 490 | width, 491 | prompt_embeds.dtype, 492 | device, 493 | generator, 494 | latents, 495 | ) 496 | 497 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 498 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 499 | 500 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) 501 | 502 | if max_iter_to_alter is None: 503 | max_iter_to_alter = len(self.scheduler.timesteps) + 1 504 | 505 | # 7. Denoising loop 506 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 507 | with self.progress_bar(total=num_inference_steps) as progress_bar: 508 | for i, t in enumerate(timesteps): 509 | 510 | with torch.enable_grad(): 511 | 512 | latents = latents.clone().detach().requires_grad_(True) 513 | 514 | # Forward pass of denoising with text conditioning 515 | noise_pred_text = self.unet(latents, t, 516 | encoder_hidden_states=prompt_embeds[1].unsqueeze(0), cross_attention_kwargs=cross_attention_kwargs).sample 517 | self.unet.zero_grad() 518 | 519 | # Get max activation value for each subject token 520 | max_attention_per_index = self._aggregate_and_get_max_attention_per_token( 521 | attention_store=attention_store, 522 | indices_to_alter=indices_to_alter, 523 | attention_res=attention_res, 524 | smooth_attentions=smooth_attentions, 525 | sigma=sigma, 526 | kernel_size=kernel_size, 527 | normalize_eot=sd_2_1) 528 | 529 | if not run_standard_sd: 530 | 531 | loss = self._compute_loss(max_attention_per_index=max_attention_per_index) 532 | 533 | # If this is an iterative refinement step, verify we have reached the desired threshold for all 534 | if i in thresholds.keys() and loss > 1. 
- thresholds[i]: 535 | del noise_pred_text 536 | torch.cuda.empty_cache() 537 | loss, latents, max_attention_per_index = self._perform_iterative_refinement_step( 538 | latents=latents, 539 | indices_to_alter=indices_to_alter, 540 | loss=loss, 541 | threshold=thresholds[i], 542 | text_embeddings=prompt_embeds, 543 | text_input=text_inputs, 544 | attention_store=attention_store, 545 | step_size=scale_factor * np.sqrt(scale_range[i]), 546 | t=t, 547 | attention_res=attention_res, 548 | smooth_attentions=smooth_attentions, 549 | sigma=sigma, 550 | kernel_size=kernel_size, 551 | normalize_eot=sd_2_1) 552 | 553 | # Perform gradient update 554 | if i < max_iter_to_alter: 555 | loss = self._compute_loss(max_attention_per_index=max_attention_per_index) 556 | if loss != 0: 557 | latents = self._update_latent(latents=latents, loss=loss, 558 | step_size=scale_factor * np.sqrt(scale_range[i])) 559 | print(f'Iteration {i} | Loss: {loss:0.4f}') 560 | 561 | # expand the latents if we are doing classifier free guidance 562 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 563 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 564 | 565 | # predict the noise residual 566 | noise_pred = self.unet( 567 | latent_model_input, 568 | t, 569 | encoder_hidden_states=prompt_embeds, 570 | cross_attention_kwargs=cross_attention_kwargs, 571 | ).sample 572 | 573 | # perform guidance 574 | if do_classifier_free_guidance: 575 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 576 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 577 | 578 | # compute the previous noisy sample x_t -> x_t-1 579 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 580 | 581 | # call the callback, if provided 582 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 583 | progress_bar.update() 584 | if callback is not None and i % callback_steps == 0: 585 | callback(i, t, latents) 586 | 587 | # 8. Post-processing 588 | image = self.decode_latents(latents) 589 | 590 | # 9. Run safety checker 591 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 592 | 593 | # 10. 
Convert to PIL 594 | if output_type == "pil": 595 | image = self.numpy_to_pil(image) 596 | 597 | if not return_dict: 598 | return (image, has_nsfw_concept) 599 | 600 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 601 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from typing import List 3 | 4 | import pyrallis 5 | import torch 6 | from PIL import Image 7 | 8 | from config import RunConfig 9 | from pipeline_attend_and_excite import AttendAndExcitePipeline 10 | from utils import ptp_utils, vis_utils 11 | from utils.ptp_utils import AttentionStore 12 | 13 | import warnings 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | def load_model(config: RunConfig): 18 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 19 | 20 | if config.sd_2_1: 21 | stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base" 22 | else: 23 | stable_diffusion_version = "CompVis/stable-diffusion-v1-4" 24 | stable = AttendAndExcitePipeline.from_pretrained(stable_diffusion_version).to(device) 25 | return stable 26 | 27 | 28 | def get_indices_to_alter(stable, prompt: str) -> List[int]: 29 | token_idx_to_word = {idx: stable.tokenizer.decode(t) 30 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids']) 31 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1} 32 | pprint.pprint(token_idx_to_word) 33 | token_indices = input("Please enter the a comma-separated list indices of the tokens you wish to " 34 | "alter (e.g., 2,5): ") 35 | token_indices = [int(i) for i in token_indices.split(",")] 36 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}") 37 | return token_indices 38 | 39 | 40 | def run_on_prompt(prompt: List[str], 41 | model: AttendAndExcitePipeline, 42 | controller: AttentionStore, 43 | token_indices: List[int], 44 | seed: torch.Generator, 45 | config: RunConfig) -> Image.Image: 46 | if controller is not None: 47 | ptp_utils.register_attention_control(model, controller) 48 | outputs = model(prompt=prompt, 49 | attention_store=controller, 50 | indices_to_alter=token_indices, 51 | attention_res=config.attention_res, 52 | guidance_scale=config.guidance_scale, 53 | generator=seed, 54 | num_inference_steps=config.n_inference_steps, 55 | max_iter_to_alter=config.max_iter_to_alter, 56 | run_standard_sd=config.run_standard_sd, 57 | thresholds=config.thresholds, 58 | scale_factor=config.scale_factor, 59 | scale_range=config.scale_range, 60 | smooth_attentions=config.smooth_attentions, 61 | sigma=config.sigma, 62 | kernel_size=config.kernel_size, 63 | sd_2_1=config.sd_2_1) 64 | image = outputs.images[0] 65 | return image 66 | 67 | 68 | @pyrallis.wrap() 69 | def main(config: RunConfig): 70 | stable = load_model(config) 71 | token_indices = get_indices_to_alter(stable, config.prompt) if config.token_indices is None else config.token_indices 72 | 73 | images = [] 74 | for seed in config.seeds: 75 | print(f"Seed: {seed}") 76 | g = torch.Generator('cuda').manual_seed(seed) 77 | controller = AttentionStore() 78 | image = run_on_prompt(prompt=config.prompt, 79 | model=stable, 80 | controller=controller, 81 | token_indices=token_indices, 82 | seed=g, 83 | config=config) 84 | prompt_output_path = config.output_path / config.prompt 85 | prompt_output_path.mkdir(exist_ok=True, parents=True) 86 | image.save(prompt_output_path / 
f'{seed}.png') 87 | images.append(image) 88 | 89 | # save a grid of results across all seeds 90 | joined_image = vis_utils.get_image_grid(images) 91 | joined_image.save(config.output_path / f'{config.prompt}.png') 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/Attend-and-Excite/163efdfd341bf3590df3c0c2b582935fbc8e8343/utils/__init__.py -------------------------------------------------------------------------------- /utils/gaussian_smoothing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 12 | in the input using a depthwise convolution. 13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the data. 19 | Default value is 2 (spatial). 20 | """ 21 | def __init__(self, channels, kernel_size, sigma, dim=2): 22 | super(GaussianSmoothing, self).__init__() 23 | if isinstance(kernel_size, numbers.Number): 24 | kernel_size = [kernel_size] * dim 25 | if isinstance(sigma, numbers.Number): 26 | sigma = [sigma] * dim 27 | 28 | # The gaussian kernel is the product of the 29 | # gaussian function of each dimension. 30 | kernel = 1 31 | meshgrids = torch.meshgrid( 32 | [ 33 | torch.arange(size, dtype=torch.float32) 34 | for size in kernel_size 35 | ] 36 | ) 37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 38 | mean = (size - 1) / 2 39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 40 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 41 | 42 | # Make sure sum of values in gaussian kernel equals 1. 43 | kernel = kernel / torch.sum(kernel) 44 | 45 | # Reshape to depthwise convolutional weight 46 | kernel = kernel.view(1, 1, *kernel.size()) 47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 48 | 49 | self.register_buffer('weight', kernel) 50 | self.groups = channels 51 | 52 | if dim == 1: 53 | self.conv = F.conv1d 54 | elif dim == 2: 55 | self.conv = F.conv2d 56 | elif dim == 3: 57 | self.conv = F.conv3d 58 | else: 59 | raise RuntimeError( 60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 61 | ) 62 | 63 | def forward(self, input): 64 | """ 65 | Apply gaussian filter to input. 66 | Arguments: 67 | input (torch.Tensor): Input to apply gaussian filter on. 68 | Returns: 69 | filtered (torch.Tensor): Filtered output. 70 | """ 71 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) 72 | 73 | 74 | class AverageSmoothing(nn.Module): 75 | """ 76 | Apply average smoothing on a 77 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 78 | in the input using a depthwise convolution. 79 | Arguments: 80 | channels (int, sequence): Number of channels of the input tensors. Output will 81 | have this number of channels as well. 82 | kernel_size (int, sequence): Size of the average kernel. 
83 | dim (int, optional): The number of dimensions of the data. 84 | Default value is 2 (spatial). 85 | 86 | """ 87 | def __init__(self, channels, kernel_size, dim=2): 88 | super(AverageSmoothing, self).__init__() 89 | 90 | # Make sure sum of values in average kernel equals 1. 91 | kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size) 92 | 93 | # Reshape to depthwise convolutional weight 94 | kernel = kernel.view(1, 1, *kernel.size()) 95 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 96 | 97 | self.register_buffer('weight', kernel) 98 | self.groups = channels 99 | 100 | if dim == 1: 101 | self.conv = F.conv1d 102 | elif dim == 2: 103 | self.conv = F.conv2d 104 | elif dim == 3: 105 | self.conv = F.conv3d 106 | else: 107 | raise RuntimeError( 108 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 109 | ) 110 | 111 | def forward(self, input): 112 | """ 113 | Apply average filter to input. 114 | Arguments: 115 | input (torch.Tensor): Input to apply average filter on. 116 | Returns: 117 | filtered (torch.Tensor): Filtered output. 118 | """ 119 | return self.conv(input, weight=self.weight, groups=self.groups) 120 | -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from IPython.display import display 7 | from PIL import Image 8 | from typing import Union, Tuple, List 9 | 10 | from diffusers.models.cross_attention import CrossAttention 11 | 12 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: 13 | h, w, c = image.shape 14 | offset = int(h * .2) 15 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 16 | font = cv2.FONT_HERSHEY_SIMPLEX 17 | img[:h] = image 18 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 21 | return img 22 | 23 | 24 | def view_images(images: Union[np.ndarray, List], 25 | num_rows: int = 1, 26 | offset_ratio: float = 0.02, 27 | display_image: bool = True) -> Image.Image: 28 | """ Displays a list of images in a grid. 
""" 29 | if type(images) is list: 30 | num_empty = len(images) % num_rows 31 | elif images.ndim == 4: 32 | num_empty = images.shape[0] % num_rows 33 | else: 34 | images = [images] 35 | num_empty = 0 36 | 37 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 38 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 39 | num_items = len(images) 40 | 41 | h, w, c = images[0].shape 42 | offset = int(h * offset_ratio) 43 | num_cols = num_items // num_rows 44 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 45 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 46 | for i in range(num_rows): 47 | for j in range(num_cols): 48 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 49 | i * num_cols + j] 50 | 51 | pil_img = Image.fromarray(image_) 52 | if display_image: 53 | display(pil_img) 54 | return pil_img 55 | 56 | 57 | class AttendExciteCrossAttnProcessor: 58 | 59 | def __init__(self, attnstore, place_in_unet): 60 | super().__init__() 61 | self.attnstore = attnstore 62 | self.place_in_unet = place_in_unet 63 | 64 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 65 | batch_size, sequence_length, _ = hidden_states.shape 66 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 67 | 68 | query = attn.to_q(hidden_states) 69 | 70 | is_cross = encoder_hidden_states is not None 71 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 72 | key = attn.to_k(encoder_hidden_states) 73 | value = attn.to_v(encoder_hidden_states) 74 | 75 | query = attn.head_to_batch_dim(query) 76 | key = attn.head_to_batch_dim(key) 77 | value = attn.head_to_batch_dim(value) 78 | 79 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 80 | 81 | self.attnstore(attention_probs, is_cross, self.place_in_unet) 82 | 83 | hidden_states = torch.bmm(attention_probs, value) 84 | hidden_states = attn.batch_to_head_dim(hidden_states) 85 | 86 | # linear proj 87 | hidden_states = attn.to_out[0](hidden_states) 88 | # dropout 89 | hidden_states = attn.to_out[1](hidden_states) 90 | 91 | return hidden_states 92 | 93 | 94 | def register_attention_control(model, controller): 95 | 96 | attn_procs = {} 97 | cross_att_count = 0 98 | for name in model.unet.attn_processors.keys(): 99 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 100 | if name.startswith("mid_block"): 101 | hidden_size = model.unet.config.block_out_channels[-1] 102 | place_in_unet = "mid" 103 | elif name.startswith("up_blocks"): 104 | block_id = int(name[len("up_blocks.")]) 105 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 106 | place_in_unet = "up" 107 | elif name.startswith("down_blocks"): 108 | block_id = int(name[len("down_blocks.")]) 109 | hidden_size = model.unet.config.block_out_channels[block_id] 110 | place_in_unet = "down" 111 | else: 112 | continue 113 | 114 | cross_att_count += 1 115 | attn_procs[name] = AttendExciteCrossAttnProcessor( 116 | attnstore=controller, place_in_unet=place_in_unet 117 | ) 118 | 119 | model.unet.set_attn_processor(attn_procs) 120 | controller.num_att_layers = cross_att_count 121 | 122 | 123 | class AttentionControl(abc.ABC): 124 | 125 | def step_callback(self, x_t): 126 | return x_t 127 | 128 | def between_steps(self): 129 | return 130 | 131 | @property 132 | def 
num_uncond_att_layers(self): 133 | return 0 134 | 135 | @abc.abstractmethod 136 | def forward(self, attn, is_cross: bool, place_in_unet: str): 137 | raise NotImplementedError 138 | 139 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 140 | if self.cur_att_layer >= self.num_uncond_att_layers: 141 | self.forward(attn, is_cross, place_in_unet) 142 | self.cur_att_layer += 1 143 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 144 | self.cur_att_layer = 0 145 | self.cur_step += 1 146 | self.between_steps() 147 | 148 | def reset(self): 149 | self.cur_step = 0 150 | self.cur_att_layer = 0 151 | 152 | def __init__(self): 153 | self.cur_step = 0 154 | self.num_att_layers = -1 155 | self.cur_att_layer = 0 156 | 157 | 158 | class EmptyControl(AttentionControl): 159 | 160 | def forward(self, attn, is_cross: bool, place_in_unet: str): 161 | return attn 162 | 163 | 164 | class AttentionStore(AttentionControl): 165 | 166 | @staticmethod 167 | def get_empty_store(): 168 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 169 | "down_self": [], "mid_self": [], "up_self": []} 170 | 171 | def forward(self, attn, is_cross: bool, place_in_unet: str): 172 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 173 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 174 | self.step_store[key].append(attn) 175 | return attn 176 | 177 | def between_steps(self): 178 | self.attention_store = self.step_store 179 | if self.save_global_store: 180 | with torch.no_grad(): 181 | if len(self.global_store) == 0: 182 | self.global_store = self.step_store 183 | else: 184 | for key in self.global_store: 185 | for i in range(len(self.global_store[key])): 186 | self.global_store[key][i] += self.step_store[key][i].detach() 187 | self.step_store = self.get_empty_store() 188 | self.step_store = self.get_empty_store() 189 | 190 | def get_average_attention(self): 191 | average_attention = self.attention_store 192 | return average_attention 193 | 194 | def get_average_global_attention(self): 195 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 196 | self.attention_store} 197 | return average_attention 198 | 199 | def reset(self): 200 | super(AttentionStore, self).reset() 201 | self.step_store = self.get_empty_store() 202 | self.attention_store = {} 203 | self.global_store = {} 204 | 205 | def __init__(self, save_global_store=False): 206 | ''' 207 | Initialize an empty AttentionStore 208 | :param save_global_store: whether to also accumulate the attention maps across all diffusion steps in global_store 209 | ''' 210 | super(AttentionStore, self).__init__() 211 | self.save_global_store = save_global_store 212 | self.step_store = self.get_empty_store() 213 | self.attention_store = {} 214 | self.global_store = {} 215 | self.curr_step_index = 0 216 | 217 | 218 | def aggregate_attention(attention_store: AttentionStore, 219 | res: int, 220 | from_where: List[str], 221 | is_cross: bool, 222 | select: int) -> torch.Tensor: 223 | """ Aggregates the attention across the different layers and heads at the specified resolution. 
""" 224 | out = [] 225 | attention_maps = attention_store.get_average_attention() 226 | num_pixels = res ** 2 227 | for location in from_where: 228 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 229 | if item.shape[1] == num_pixels: 230 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 231 | out.append(cross_maps) 232 | out = torch.cat(out, dim=0) 233 | out = out.sum(0) / out.shape[0] 234 | return out 235 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from PIL import Image 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from utils import ptp_utils 9 | from utils.ptp_utils import AttentionStore, aggregate_attention 10 | 11 | 12 | def show_cross_attention(prompt: str, 13 | attention_store: AttentionStore, 14 | tokenizer, 15 | indices_to_alter: List[int], 16 | res: int, 17 | from_where: List[str], 18 | select: int = 0, 19 | orig_image=None): 20 | tokens = tokenizer.encode(prompt) 21 | decoder = tokenizer.decode 22 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu() 23 | images = [] 24 | 25 | # show spatial attention for indices of tokens to strengthen 26 | for i in range(len(tokens)): 27 | image = attention_maps[:, :, i] 28 | if i in indices_to_alter: 29 | image = show_image_relevance(image, orig_image) 30 | image = image.astype(np.uint8) 31 | image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2))) 32 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 33 | images.append(image) 34 | 35 | ptp_utils.view_images(np.stack(images, axis=0)) 36 | 37 | 38 | def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16): 39 | # create heatmap from mask on image 40 | def show_cam_on_image(img, mask): 41 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 42 | heatmap = np.float32(heatmap) / 255 43 | cam = heatmap + np.float32(img) 44 | cam = cam / np.max(cam) 45 | return cam 46 | 47 | image = image.resize((relevnace_res ** 2, relevnace_res ** 2)) 48 | image = np.array(image) 49 | 50 | image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1]) 51 | image_relevance = image_relevance.cuda() # because float16 precision interpolation is not supported on cpu 52 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear') 53 | image_relevance = image_relevance.cpu() # send it back to cpu 54 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) 55 | image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2) 56 | image = (image - image.min()) / (image.max() - image.min()) 57 | vis = show_cam_on_image(image, image_relevance) 58 | vis = np.uint8(255 * vis) 59 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 60 | return vis 61 | 62 | 63 | def get_image_grid(images: List[Image.Image]) -> Image: 64 | num_images = len(images) 65 | cols = int(math.ceil(math.sqrt(num_images))) 66 | rows = int(math.ceil(num_images / cols)) 67 | width, height = images[0].size 68 | grid_image = Image.new('RGB', (cols * width, rows * height)) 69 | for i, img in enumerate(images): 70 | x = i % cols 71 | y = i // cols 72 | grid_image.paste(img, (x * width, y * height)) 73 | return grid_image 74 | 
--------------------------------------------------------------------------------
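
Usage sketch (editor's addition, not part of the repository): the snippet below mirrors run.py's load_model and run_on_prompt, showing how the pipeline, the AttentionStore controller, and the visualization utilities fit together when driven programmatically instead of through pyrallis. The keyword arguments are exactly those forwarded by run_on_prompt, but the prompt, token indices, and hyperparameter values are illustrative assumptions; consult config.py for the actual defaults.

# Minimal, hedged example of running Attend-and-Excite without run.py's CLI.
import torch

from pipeline_attend_and_excite import AttendAndExcitePipeline
from utils import ptp_utils, vis_utils
from utils.ptp_utils import AttentionStore

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

prompt = "a cat and a dog"   # example prompt (assumption)
token_indices = [2, 5]       # example indices of "cat" and "dog" in the tokenized prompt (cf. get_indices_to_alter)

# Route the UNet's attention through the AttendExciteCrossAttnProcessor so maps are recorded.
controller = AttentionStore()
ptp_utils.register_attention_control(stable, controller)

generator = torch.Generator(device.type).manual_seed(42)
outputs = stable(prompt=prompt,
                 attention_store=controller,
                 indices_to_alter=token_indices,
                 attention_res=16,                        # illustrative values below; see config.py
                 guidance_scale=7.5,
                 generator=generator,
                 num_inference_steps=50,
                 max_iter_to_alter=25,
                 run_standard_sd=False,
                 thresholds={0: 0.05, 10: 0.5, 20: 0.8},
                 scale_factor=20,
                 scale_range=(1.0, 0.5),
                 smooth_attentions=True,
                 sigma=0.5,
                 kernel_size=3,
                 sd_2_1=False)
outputs.images[0].save("42.png")

# Optional (requires a GPU, best run in a notebook): visualize the cross-attention
# maps of the altered tokens on top of the generated image.
vis_utils.show_cross_attention(prompt=prompt,
                               attention_store=controller,
                               tokenizer=stable.tokenizer,
                               indices_to_alter=token_indices,
                               res=16,
                               from_where=("up", "down", "mid"),
                               orig_image=outputs.images[0])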