├── .gitignore ├── README.md ├── docs ├── generated.jpg ├── real.jpg ├── seg.jpg └── teaser.jpg ├── gradio_app.py ├── lpm_env.yml ├── main.py ├── real_images ├── lamp_simple.png └── rinon_cat.jpg ├── requirements.txt ├── run_segmentation.py ├── src ├── attention_based_segmentation.py ├── attention_utils.py ├── diffusion_model_wrapper.py ├── null_text_inversion.py ├── prompt_mixing.py ├── prompt_to_prompt_controllers.py ├── prompt_utils.py └── seq_aligner.py ├── style.css └── vocab.json /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | results/ 3 | .vscode/ 4 | segmentation_results/ 5 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Localizing Object-level Shape Variations with Text-to-Image Diffusion Models (ICCV 2023) 2 | 3 | > **Or Patashnik, Daniel Garibi, Idan Azuri, Hadar Averbuch-Elor, Daniel Cohen-Or** 4 | > 5 | > Text-to-image models give rise to workflows which often begin with an exploration step, where users sift through a large collection of generated images. The global nature of the text-to-image generation process prevents users from narrowing their exploration to a particular object in the image. In this paper, we present a technique to generate a collection of images that depicts variations in the shape of a specific object, enabling an object-level shape exploration process. Creating plausible variations is challenging as it requires control over the shape of the generated object while respecting its semantics. A particular challenge when generating object variations is accurately localizing the manipulation applied over the object's shape. We introduce a prompt-mixing technique that switches between prompts along the denoising process to attain a variety of shape choices. 
To localize the image-space operation, we present two techniques that use the self-attention layers in conjunction with the cross-attention layers. Moreover, we show that these localization techniques are general and effective beyond the scope of generating object variations. Extensive results and comparisons demonstrate the effectiveness of our method in generating object variations, and the competence of our localization techniques. 6 | 7 | 8 | 9 | 10 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/orpatashnik/local-prompt-mixing) 11 | 12 |

13 | 14 |

15 | 16 | ## Description 17 | Official implementation of our paper, Localizing Object-level Shape Variations with Text-to-Image Diffusion Models. 18 | 19 | ## Setup 20 | 21 | ### Environment 22 | Our code builds on the requirements of the official [Stable Diffusion repository](https://github.com/CompVis/stable-diffusion). To set up the environment, run: 23 | 24 | ``` 25 | conda env create -f lpm_env.yml 26 | conda activate lpm 27 | ``` 28 | 29 | Then, download the required NLTK data by running the following in Python: 30 | ```python 31 | import nltk 32 | nltk.download('punkt') 33 | nltk.download('averaged_perceptron_tagger') 34 | ``` 35 | 36 | This project has a Gradio [demo](https://huggingface.co/spaces/orpatashnik/local-prompt-mixing) deployed on Hugging Face Spaces. 37 | To run the demo locally, run the following: 38 | ```shell 39 | gradio gradio_app.py 40 | ``` 41 | Then, open the local demo by browsing to `http://localhost:7860/`. 42 | 43 | ## Prompt Mix & Match Usage 44 | 45 | ### Generated Images 46 | 47 |

48 | 49 |
50 | Example generations by Stable Diffusion with variations outputted by Prompt Mix & Match. 51 |

52 | 53 | 54 | To generate an image and its variations, run the `main.py` script. For example: 55 | ``` 56 | python main.py --seed 48 --prompt "hamster eating {word} on the beach" --object_of_interest "watermelon" --background_nouns=["beach","hamster"] 57 | ``` 58 | Notes: 59 | 60 | - To choose the number of variations (used only when proxy words are generated automatically), specify `--number_of_variations 20`. 61 | - You may use your own proxy words instead of the auto-generated ones. For example, `--proxy_words=["pizza","ball","cube"]`. 62 | - To change the shape interval ($T_3$ and $T_2$ in the paper), specify `--start_prompt_range 5 --end_prompt_range 15`. 63 | - You may use self-attention localization to preserve other objects, for example `--objects_to_preserve=["hamster"]`. 64 | - By default, the object of interest is removed from the self-attention mask (`--remove_obj_from_self_mask True`); pass `--remove_obj_from_self_mask False` to keep it. 65 | 66 | All generated images will be saved to the path `"{exp_dir}/{prompt}/seed={seed}_{exp_name}/"`, where spaces in `{prompt}` are replaced with `-`: 67 | ``` 68 | {exp_dir}/ 69 | |-- {prompt}/ 70 | | |-- seed={seed}_{exp_name}/ 71 | | |-- {object_of_interest}.jpg 72 | | |-- {proxy_words[0]}.jpg 73 | | |-- {proxy_words[1]}.jpg 74 | | |-- ... 75 | | |-- grid.jpg 76 | | |-- opt.json 77 | | |-- seed={seed}_{exp_name}/ 78 | | |-- {object_of_interest}.jpg 79 | | |-- {proxy_words[0]}.jpg 80 | | |-- {proxy_words[1]}.jpg 81 | | |-- ... 82 | | |-- grid.jpg 83 | | |-- opt.json 84 | ... 85 | ``` 86 | The default values are `--exp_dir "results" --exp_name ""`. 87 | ### Real Images 88 | 89 |

90 | 91 |
92 | Example real image with variations outputted by Prompt Mix & Match. 93 |

94 | 95 | To generate variations of a real image, pass its path with `--real_image_path` together with a prompt that describes it. For example: 96 | ``` 97 | python main.py --real_image_path "real_images/lamp_simple.png" --prompt "A table below a {word}" --object_of_interest "lamp" --objects_to_preserve=["table"] --background_nouns=["table"] 98 | ``` 99 | 100 | All generated images will be saved in the same format as for generated images. 101 | 102 | ## Self-Segmentation 103 | 104 |

105 | 106 |
107 | Example segmentation of a real image. 108 |

110 | To segment an image, run the `run_segmentation.py` script. 111 | The segmentation parameters are defined in the script's `SegmentationConfig` class: 112 | - To use a real image, set its path in the `real_image_path` attribute and set `prompt` to a description of the image. 113 | - You may change the number of segments with the `num_segments` parameter. 114 | 115 | The outputs will be saved to the path `"{exp_path}/"` (default: `segmentation_results`): 116 | ``` 117 | {exp_path}/ 118 | | |-- real_image.jpg 119 | | |-- image_enc.jpg 120 | | |-- image_rec.jpg 121 | | |-- segmentation.jpg 122 | ``` 123 | - `real_image.jpg` - Original image. 124 | - `image_enc.jpg` - Reconstruction of the image by the Stable Diffusion autoencoder only. 125 | - `image_rec.jpg` - Reconstruction of the image by the full Stable Diffusion pipeline. 126 | - `segmentation.jpg` - Segmentation output. 127 | 128 | 129 | ## Acknowledgements 130 | This code builds on the code from the [diffusers](https://github.com/huggingface/diffusers) library as well as the [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) codebase.
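The self-segmentation step can be illustrated as two stages: cluster the per-pixel self-attention maps, then name each cluster using cross-attention. The sketch below is an illustration under stated assumptions, not the code in `src/attention_based_segmentation.py` (which imports scikit-learn's `KMeans` and aggregates attention at `res=32`). It hand-rolls Lloyd's k-means in NumPy with farthest-point seeding, and the labeling rule in the hypothetical `label_clusters` helper (mean of a max-normalized cross-attention map inside each cluster, compared against a threshold playing the role of `background_segment_threshold`) is an assumption chosen for clarity.

```python
import numpy as np


def kmeans(features: np.ndarray, k: int, iters: int = 50, seed: int = 0) -> np.ndarray:
    """Lloyd's k-means with farthest-point seeding; returns one label per row."""
    rng = np.random.default_rng(seed)
    # Seeding: one random row, then repeatedly the row farthest from all chosen centers
    centers = [features[rng.integers(len(features))]]
    for _ in range(1, k):
        dist = np.min([((features - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(features[dist.argmax()])
    centers = np.stack(centers)

    labels = np.zeros(len(features), dtype=int)
    for _ in range(iters):
        # (N, k) squared Euclidean distances from every row to every center
        dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its rows (keep it if the cluster is empty)
        new_centers = np.stack([
            features[labels == c].mean(axis=0) if np.any(labels == c) else centers[c]
            for c in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def segment(self_attn: np.ndarray, num_segments: int) -> np.ndarray:
    """self_attn: (res*res, res*res); row i is pixel i's attention over all pixels."""
    res = int(round(np.sqrt(self_attn.shape[0])))
    return kmeans(self_attn, num_segments).reshape(res, res)


def label_clusters(clusters: np.ndarray, cross_attn: dict, threshold: float) -> dict:
    """Hypothetical rule: give each cluster the noun whose max-normalized
    cross-attention map has the highest mean inside it; 'BG' if none exceeds threshold."""
    names = {}
    for c in np.unique(clusters):
        mask = clusters == c
        best_name, best_score = "BG", threshold
        for noun, amap in cross_attn.items():
            score = (amap / (amap.max() + 1e-8))[mask].mean()
            if score > best_score:
                best_name, best_score = noun, score
        names[int(c)] = best_name
    return names


if __name__ == "__main__":
    res = 4
    left = (np.arange(res * res) % res) < res // 2  # pixels in the left half
    # Toy self-attention: each pixel attends uniformly to pixels on its own side
    self_attn = np.where(left[:, None] == left[None, :], 1.0 / left.sum(), 0.0)
    clusters = segment(self_attn, num_segments=2)
    cat_map = left.reshape(res, res).astype(float)  # toy cross-attention map for "cat"
    print(clusters)
    print(label_clusters(clusters, {"cat": cat_map}, threshold=0.35))
```

On this toy input the two halves of the 4x4 grid form the two clusters; the left one is labeled `cat` and the right one falls back to `BG`. Farthest-point seeding is used instead of purely random seeding so that two identical rows are never picked as the initial centers.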
131 | 132 | ## Citation 133 | 134 | If you use this code for your research, please cite our paper: 135 | 136 | ``` 137 | @InProceedings{patashnik2023localizing, 138 | author = {Patashnik, Or and Garibi, Daniel and Azuri, Idan and Averbuch-Elor, Hadar and Cohen-Or, Daniel}, 139 | title = {Localizing Object-level Shape Variations with Text-to-Image Diffusion Models}, 140 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 141 | year = {2023} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /docs/generated.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/generated.jpg -------------------------------------------------------------------------------- /docs/real.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/real.jpg -------------------------------------------------------------------------------- /docs/seg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/seg.jpg -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/teaser.jpg -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import gradio as gr 4 | import nltk 5 | import numpy 
as np 6 | from PIL import Image 7 | 8 | nltk.download('punkt') 9 | nltk.download('averaged_perceptron_tagger') 10 | 11 | from main import LPMConfig, main, setup 12 | 13 | DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models 14 | This is a demo for our "Localizing Object-level Shape Variations with Text-to-Image Diffusion Models" [paper](https://arxiv.org/abs/2303.11306). 15 | We introduce a method that generates object-level shape variations for a given image. 16 | This demo supports both generated images and real images. To modify a real image, please upload it to the input image block and provide a prompt that describes its contents. 17 | 18 | ''' 19 | 20 | stable, stable_config = setup(LPMConfig()) 21 | 22 | def main_pipeline( 23 | prompt: str, 24 | object_of_interest: str, 25 | proxy_words: str, 26 | number_of_variations: int, 27 | start_prompt_range: int, 28 | end_prompt_range: int, 29 | objects_to_preserve: str, 30 | background_nouns: str, 31 | seed: int, 32 | input_image: str): 33 | prompt = prompt.replace(object_of_interest, '{word}') 34 | proxy_words = [w.strip() for w in proxy_words.split(',')] if proxy_words != '' else [] 35 | objects_to_preserve = [w.strip() for w in objects_to_preserve.split(',')] if objects_to_preserve != '' else [] 36 | background_nouns = [w.strip() for w in background_nouns.split(',')] if background_nouns != '' else [] 37 | args = LPMConfig( 38 | seed=seed, 39 | prompt=prompt, 40 | object_of_interest=object_of_interest, 41 | proxy_words=proxy_words, 42 | number_of_variations=number_of_variations, 43 | start_prompt_range=start_prompt_range, 44 | end_prompt_range=end_prompt_range, 45 | objects_to_preserve=objects_to_preserve, 46 | background_nouns=background_nouns, 47 | real_image_path="" if input_image is None else input_image 48 | ) 49 | 50 | result_images, result_proxy_words = main(stable, stable_config, args) 51 | result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images] 52 | result_images = [(im * 255).astype(np.uint8) for im in
result_images] 53 | result_images = [Image.fromarray(im) for im in result_images] 54 | 55 | return result_images, ",".join(result_proxy_words) 56 | 57 | 58 | with gr.Blocks(css='style.css') as demo: 59 | gr.Markdown(DESCRIPTION) 60 | 61 | gr.HTML( 62 | ''' 63 | Duplicate SpaceDuplicate the Space to run privately without waiting in queue''') 64 | 65 | with gr.Row(): 66 | with gr.Column(): 67 | input_image = gr.Image( 68 | label="Input image (optional)", 69 | type="filepath" 70 | ) 71 | prompt = gr.Text( 72 | label='Prompt', 73 | max_lines=1, 74 | placeholder='A table below a lamp', 75 | ) 76 | object_of_interest = gr.Text( 77 | label='Object of interest', 78 | max_lines=1, 79 | placeholder='lamp', 80 | ) 81 | proxy_words = gr.Text( 82 | label='Proxy words - words used to obtain variations (a comma-separated list of words, can leave empty)', 83 | max_lines=1, 84 | placeholder='' 85 | ) 86 | number_of_variations = gr.Slider( 87 | label='Number of variations (used only for automatic proxy-words)', 88 | minimum=2, 89 | maximum=30, 90 | value=7, 91 | step=1 92 | ) 93 | start_prompt_range = gr.Slider( 94 | label='Number of steps before starting shape interval', 95 | minimum=0, 96 | maximum=50, 97 | value=7, 98 | step=1 99 | ) 100 | end_prompt_range = gr.Slider( 101 | label='Number of steps before ending shape interval', 102 | minimum=1, 103 | maximum=50, 104 | value=17, 105 | step=1 106 | ) 107 | objects_to_preserve = gr.Text( 108 | label='Words corresponding to objects to preserve (a comma-separated list of words, can leave empty)', 109 | max_lines=1, 110 | placeholder='table', 111 | ) 112 | background_nouns = gr.Text( 113 | label='Words corresponding to objects that should be copied from the original image (a comma-separated list of words, can leave empty)', 114 | max_lines=1, 115 | placeholder='', 116 | ) 117 | seed = gr.Slider( 118 | label='Seed', 119 | minimum=0, 120 | maximum=100000, 121 | value=0, 122 | step=1 123 | ) 124 | 125 | run_button = gr.Button('Generate') 126 |
| with gr.Column(): 127 | result = gr.Gallery(label='Result').style(grid=4) 128 | proxy_words_result = gr.Text(label='Used proxy words') 129 | 130 | examples = [ 131 | [ 132 | "hamster eating watermelon on the beach", 133 | "watermelon", 134 | "", 135 | 7, 136 | 6, 137 | 16, 138 | "", 139 | "hamster,beach", 140 | 48, 141 | None 142 | ], 143 | [ 144 | "A decorated lamp in the livingroom", 145 | "lamp", 146 | "", 147 | 7, 148 | 4, 149 | 14, 150 | "livingroom", 151 | "", 152 | 42, 153 | None 154 | ], 155 | [ 156 | "a snake in the field eats an apple", 157 | "snake", 158 | "", 159 | 7, 160 | 7, 161 | 17, 162 | "apple", 163 | "apple,field", 164 | 10, 165 | None 166 | ] 167 | ] 168 | 169 | gr.Examples(examples=examples, 170 | inputs=[ 171 | prompt, 172 | object_of_interest, 173 | proxy_words, 174 | number_of_variations, 175 | start_prompt_range, 176 | end_prompt_range, 177 | objects_to_preserve, 178 | background_nouns, 179 | seed, 180 | input_image 181 | ], 182 | outputs=[ 183 | result, 184 | proxy_words_result 185 | ], 186 | fn=main_pipeline, 187 | cache_examples=False) 188 | 189 | 190 | inputs = [ 191 | prompt, 192 | object_of_interest, 193 | proxy_words, 194 | number_of_variations, 195 | start_prompt_range, 196 | end_prompt_range, 197 | objects_to_preserve, 198 | background_nouns, 199 | seed, 200 | input_image 201 | ] 202 | outputs = [ 203 | result, 204 | proxy_words_result 205 | ] 206 | run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs) 207 | 208 | demo.queue(max_size=50).launch(share=False) -------------------------------------------------------------------------------- /lpm_env.yml: -------------------------------------------------------------------------------- 1 | name: lpm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_gnu 11 | - anyio=3.6.2=pyhd8ed1ab_0 12 | - argon2-cffi=21.3.0=pyhd8ed1ab_0 13 | - 
argon2-cffi-bindings=21.2.0=py310h5764c6d_3 14 | - asttokens=2.2.1=pyhd8ed1ab_0 15 | - attrs=22.2.0=pyh71513ae_0 16 | - backcall=0.2.0=pyh9f0ad1d_0 17 | - backports=1.0=pyhd8ed1ab_3 18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 19 | - beautifulsoup4=4.11.2=pyha770c72_0 20 | - blas=1.0=mkl 21 | - bleach=6.0.0=pyhd8ed1ab_0 22 | - brotli=1.0.9=h166bdaf_7 23 | - brotli-bin=1.0.9=h166bdaf_7 24 | - brotlipy=0.7.0=py310h7f8727e_1002 25 | - bzip2=1.0.8=h7b6447c_0 26 | - ca-certificates=2022.12.7=ha878542_0 27 | - certifi=2022.12.7=pyhd8ed1ab_0 28 | - cffi=1.15.1=py310h5eee18b_3 29 | - chardet=4.0.0=py310h06a4308_1003 30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 31 | - comm=0.1.2=pyhd8ed1ab_0 32 | - contourpy=1.0.5=py310hdb19cb5_0 33 | - cryptography=38.0.4=py310h9ce1e76_0 34 | - cuda=11.7.1=0 35 | - cuda-cccl=11.7.91=0 36 | - cuda-command-line-tools=11.7.1=0 37 | - cuda-compiler=11.7.1=0 38 | - cuda-cudart=11.7.99=0 39 | - cuda-cudart-dev=11.7.99=0 40 | - cuda-cuobjdump=11.7.91=0 41 | - cuda-cupti=11.7.101=0 42 | - cuda-cuxxfilt=11.7.91=0 43 | - cuda-demo-suite=12.0.76=0 44 | - cuda-documentation=12.0.76=0 45 | - cuda-driver-dev=11.7.99=0 46 | - cuda-gdb=12.0.90=0 47 | - cuda-libraries=11.7.1=0 48 | - cuda-libraries-dev=11.7.1=0 49 | - cuda-memcheck=11.8.86=0 50 | - cuda-nsight=12.0.78=0 51 | - cuda-nsight-compute=12.0.0=0 52 | - cuda-nvcc=11.7.99=0 53 | - cuda-nvdisasm=12.0.76=0 54 | - cuda-nvml-dev=11.7.91=0 55 | - cuda-nvprof=12.0.90=0 56 | - cuda-nvprune=11.7.91=0 57 | - cuda-nvrtc=11.7.99=0 58 | - cuda-nvrtc-dev=11.7.99=0 59 | - cuda-nvtx=11.7.91=0 60 | - cuda-nvvp=12.0.90=0 61 | - cuda-runtime=11.7.1=0 62 | - cuda-sanitizer-api=12.0.90=0 63 | - cuda-toolkit=11.7.1=0 64 | - cuda-tools=11.7.1=0 65 | - cuda-visual-tools=11.7.1=0 66 | - cycler=0.11.0=pyhd8ed1ab_0 67 | - dbus=1.13.18=hb2f20db_0 68 | - debugpy=1.5.1=py310h295c915_0 69 | - decorator=5.1.1=pyhd8ed1ab_0 70 | - defusedxml=0.7.1=pyhd8ed1ab_0 71 | - entrypoints=0.4=pyhd8ed1ab_0 72 | - 
executing=1.2.0=pyhd8ed1ab_0 73 | - expat=2.2.10=h9c3ff4c_0 74 | - ffmpeg=4.3=hf484d3e_0 75 | - flit-core=3.6.0=pyhd3eb1b0_0 76 | - fontconfig=2.14.1=hef1e5e3_0 77 | - fonttools=4.25.0=pyhd3eb1b0_0 78 | - freetype=2.12.1=h4a9f257_0 79 | - gds-tools=1.5.0.59=0 80 | - giflib=5.2.1=h7b6447c_0 81 | - glib=2.69.1=he621ea3_2 82 | - gmp=6.2.1=h295c915_3 83 | - gnutls=3.6.15=he1e5248_0 84 | - gst-plugins-base=1.14.0=h8213a91_2 85 | - gstreamer=1.14.0=h28cd5cc_2 86 | - icu=58.2=hf484d3e_1000 87 | - idna=3.4=py310h06a4308_0 88 | - importlib-metadata=6.0.0=pyha770c72_0 89 | - importlib_resources=5.10.2=pyhd8ed1ab_0 90 | - intel-openmp=2021.4.0=h06a4308_3561 91 | - ipykernel=6.19.2=py310h2f386ee_0 92 | - ipython=8.8.0=py310h06a4308_0 93 | - ipython_genutils=0.2.0=py_1 94 | - jedi=0.18.2=pyhd8ed1ab_0 95 | - jinja2=3.1.2=pyhd8ed1ab_1 96 | - jpeg=9e=h7f8727e_0 97 | - jsonschema=4.17.3=pyhd8ed1ab_0 98 | - jupyter_client=7.3.4=pyhd8ed1ab_0 99 | - jupyter_core=4.12.0=py310hff52083_0 100 | - jupyter_server=1.23.5=pyhd8ed1ab_0 101 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 102 | - keyutils=1.6.1=h166bdaf_0 103 | - kiwisolver=1.4.4=py310h6a678d5_0 104 | - krb5=1.19.3=h3790be6_0 105 | - lame=3.100=h7b6447c_0 106 | - lcms2=2.12=h3be6417_0 107 | - ld_impl_linux-64=2.38=h1181459_1 108 | - lerc=3.0=h295c915_0 109 | - libbrotlicommon=1.0.9=h166bdaf_7 110 | - libbrotlidec=1.0.9=h166bdaf_7 111 | - libbrotlienc=1.0.9=h166bdaf_7 112 | - libclang=10.0.1=default_hb85057a_2 113 | - libcublas=11.10.3.66=0 114 | - libcublas-dev=11.10.3.66=0 115 | - libcufft=10.7.2.124=h4fbf590_0 116 | - libcufft-dev=10.7.2.124=h98a8f43_0 117 | - libcufile=1.5.0.59=0 118 | - libcufile-dev=1.5.0.59=0 119 | - libcurand=10.3.1.50=0 120 | - libcurand-dev=10.3.1.50=0 121 | - libcusolver=11.4.0.1=0 122 | - libcusolver-dev=11.4.0.1=0 123 | - libcusparse=11.7.4.91=0 124 | - libcusparse-dev=11.7.4.91=0 125 | - libdeflate=1.8=h7f8727e_5 126 | - libedit=3.1.20191231=he28a2e2_2 127 | - libevent=2.1.12=h8f2d780_0 128 | - 
libffi=3.4.2=h6a678d5_6 129 | - libgcc-ng=12.2.0=h65d4601_19 130 | - libgomp=12.2.0=h65d4601_19 131 | - libiconv=1.16=h7f8727e_2 132 | - libidn2=2.3.2=h7f8727e_0 133 | - libllvm10=10.0.1=he513fc3_3 134 | - libnpp=11.7.4.75=0 135 | - libnpp-dev=11.7.4.75=0 136 | - libnvjpeg=11.8.0.2=0 137 | - libnvjpeg-dev=11.8.0.2=0 138 | - libpng=1.6.37=hbc83047_0 139 | - libpq=12.9=h16c4e8d_3 140 | - libsodium=1.0.18=h36c2ea0_1 141 | - libstdcxx-ng=12.2.0=h46fd767_19 142 | - libtasn1=4.16.0=h27cfd23_0 143 | - libtiff=4.5.0=hecacb30_0 144 | - libunistring=0.9.10=h27cfd23_0 145 | - libuuid=1.41.5=h5eee18b_0 146 | - libwebp=1.2.4=h11a3e52_0 147 | - libwebp-base=1.2.4=h5eee18b_0 148 | - libxcb=1.15=h7f8727e_0 149 | - libxkbcommon=1.0.3=he3ba5ed_0 150 | - libxml2=2.9.14=h74e7548_0 151 | - libxslt=1.1.35=h4e12654_0 152 | - lz4-c=1.9.4=h6a678d5_0 153 | - markupsafe=2.1.2=py310h1fa729e_0 154 | - matplotlib=3.6.2=py310hff52083_0 155 | - matplotlib-base=3.6.2=py310h945d387_0 156 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 157 | - mistune=2.0.5=pyhd8ed1ab_0 158 | - mkl=2021.4.0=h06a4308_640 159 | - mkl-service=2.4.0=py310h7f8727e_0 160 | - mkl_fft=1.3.1=py310hd6ae3a3_0 161 | - mkl_random=1.2.2=py310h00e6091_0 162 | - munkres=1.1.4=pyh9f0ad1d_0 163 | - nbclassic=0.5.1=pyhd8ed1ab_0 164 | - nbclient=0.7.2=pyhd8ed1ab_0 165 | - nbconvert=7.2.9=pyhd8ed1ab_0 166 | - nbconvert-core=7.2.9=pyhd8ed1ab_0 167 | - nbconvert-pandoc=7.2.9=pyhd8ed1ab_0 168 | - nbformat=5.7.3=pyhd8ed1ab_0 169 | - ncurses=6.3=h5eee18b_3 170 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 171 | - nettle=3.7.3=hbbd107a_1 172 | - notebook=6.5.2=pyha770c72_1 173 | - notebook-shim=0.2.2=pyhd8ed1ab_0 174 | - nsight-compute=2022.4.0.15=0 175 | - nspr=4.33=h295c915_0 176 | - nss=3.74=h0370c37_0 177 | - numpy=1.23.5=py310hd5efca6_0 178 | - numpy-base=1.23.5=py310h8e6c178_0 179 | - openh264=2.1.1=h4ff587b_0 180 | - openssl=1.1.1t=h0b41bf4_0 181 | - packaging=23.0=pyhd8ed1ab_0 182 | - pandoc=2.19.2=ha770c72_0 183 | - pandocfilters=1.5.0=pyhd8ed1ab_0 
184 | - parso=0.8.3=pyhd8ed1ab_0 185 | - pcre=8.45=h9c3ff4c_0 186 | - pexpect=4.8.0=pyh1a96a4e_2 187 | - pickleshare=0.7.5=py_1003 188 | - pillow=9.3.0=py310h6a678d5_2 189 | - pip=22.3.1=py310h06a4308_0 190 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 191 | - ply=3.11=py_1 192 | - prometheus_client=0.16.0=pyhd8ed1ab_0 193 | - prompt-toolkit=3.0.36=pyha770c72_0 194 | - psutil=5.9.4=py310h5764c6d_0 195 | - ptyprocess=0.7.0=pyhd3deb0d_0 196 | - pure_eval=0.2.2=pyhd8ed1ab_0 197 | - pycparser=2.21=pyhd3eb1b0_0 198 | - pygments=2.14.0=pyhd8ed1ab_0 199 | - pyopenssl=22.0.0=pyhd3eb1b0_0 200 | - pyparsing=3.0.9=pyhd8ed1ab_0 201 | - pyqt=5.15.7=py310h6a678d5_1 202 | - pyrsistent=0.19.3=py310h1fa729e_0 203 | - pysocks=1.7.1=py310h06a4308_0 204 | - python=3.10.8=h7a1cb2a_1 205 | - python-dateutil=2.8.2=pyhd8ed1ab_0 206 | - python-fastjsonschema=2.16.2=pyhd8ed1ab_0 207 | - python_abi=3.10=2_cp310 208 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0 209 | - pytorch-cuda=11.7=h67b0de4_1 210 | - pytorch-mutex=1.0=cuda 211 | - pyzmq=25.0.0=py310h059b190_0 212 | - qt-main=5.15.2=h327a75a_7 213 | - qt-webengine=5.15.9=hd2b0992_4 214 | - qtwebkit=5.212=h4eab89a_4 215 | - readline=8.2=h5eee18b_0 216 | - requests=2.28.1=py310h06a4308_0 217 | - send2trash=1.8.0=pyhd8ed1ab_0 218 | - setuptools=65.6.3=py310h06a4308_0 219 | - sip=6.6.2=py310h6a678d5_0 220 | - six=1.16.0=pyhd3eb1b0_1 221 | - sniffio=1.3.0=pyhd8ed1ab_0 222 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0 223 | - sqlite=3.40.1=h5082296_0 224 | - stack_data=0.6.2=pyhd8ed1ab_0 225 | - terminado=0.17.1=pyh41d4057_0 226 | - tinycss2=1.2.1=pyhd8ed1ab_0 227 | - tk=8.6.12=h1ccaba5_0 228 | - toml=0.10.2=pyhd8ed1ab_0 229 | - torchaudio=0.13.1=py310_cu117 230 | - torchvision=0.14.1=py310_cu117 231 | - tornado=6.2=py310h5eee18b_0 232 | - traitlets=5.7.1=py310h06a4308_0 233 | - typing_extensions=4.4.0=py310h06a4308_0 234 | - tzdata=2022g=h04d1e81_0 235 | - urllib3=1.26.14=py310h06a4308_0 236 | - wcwidth=0.2.6=pyhd8ed1ab_0 237 | - 
webencodings=0.5.1=py_1 238 | - websocket-client=1.5.1=pyhd8ed1ab_0 239 | - wheel=0.37.1=pyhd3eb1b0_0 240 | - xz=5.2.8=h5eee18b_0 241 | - zeromq=4.3.4=h9c3ff4c_1 242 | - zipp=3.11.0=py310h06a4308_0 243 | - zlib=1.2.13=h5eee18b_0 244 | - zstd=1.5.2=ha4553b6_0 245 | - pip: 246 | - accelerate==0.18.0 247 | - click==8.1.3 248 | - diffusers==0.10.2 249 | - filelock==3.10.4 250 | - huggingface-hub==0.13.3 251 | - joblib==1.2.0 252 | - mypy-extensions==1.0.0 253 | - nltk==3.8.1 254 | - opencv-python==4.7.0.72 255 | - pyqt5-sip==12.11.0 256 | - pyrallis==0.3.1 257 | - pyyaml==6.0 258 | - regex==2023.3.23 259 | - scikit-learn==1.2.2 260 | - scipy==1.10.1 261 | - threadpoolctl==3.1.0 262 | - tokenizers==0.13.2 263 | - tqdm==4.65.0 264 | - transformers==4.25.1 265 | - typing-inspect==0.8.0 266 | prefix: /home/danielgaribi/miniconda3/envs/lpm 267 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import List 5 | 6 | import pyrallis 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision.transforms import ToTensor 10 | from torchvision.utils import save_image 11 | from tqdm import tqdm 12 | 13 | from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \ 14 | generate_original_image 15 | from src.null_text_inversion import invert_image 16 | from src.prompt_mixing import PromptMixing 17 | from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace 18 | from src.prompt_utils import get_proxy_prompts 19 | 20 | 21 | def save_args_dict(args, similar_words): 22 | exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}") 23 | os.makedirs(exp_path, exist_ok=True) 24 | 25 | args_dict = vars(args) 26 | 
args_dict['similar_words'] = similar_words 27 | with open(os.path.join(exp_path, "opt.json"), 'w') as fp: 28 | json.dump(args_dict, fp, sort_keys=True, indent=4) 29 | 30 | return exp_path 31 | 32 | def setup(args): 33 | ldm_stable = get_stable_diffusion_model(args) 34 | ldm_stable_config = get_stable_diffusion_config(args) 35 | return ldm_stable, ldm_stable_config 36 | 37 | 38 | def main(ldm_stable, ldm_stable_config, args): 39 | 40 | similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable) 41 | exp_path = save_args_dict(args, similar_words) 42 | 43 | images = [] 44 | x_t = None 45 | uncond_embeddings = None 46 | 47 | if args.real_image_path != "": 48 | ldm_stable, ldm_stable_config = setup(args) 49 | x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path) 50 | 51 | image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings) 52 | save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg") 53 | save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg") 54 | images.append(image[0]) 55 | 56 | object_of_interest_index = args.prompt.split().index('{word}') + 1 57 | pm = PromptMixing(args, object_of_interest_index, average_attention) 58 | 59 | do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0 60 | do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0 61 | if do_other_obj_self_attn_masking: 62 | print("Do self attn other obj masking") 63 | if do_self_or_cross_attn_inject: 64 | print(f'Do self attn inject for {args.self_attn_inject_steps} steps') 65 | print(f'Do cross attn inject for {args.cross_attn_inject_steps} steps') 66 | 67 | another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False) 68 | 69 | for 
another_prompt_batch in tqdm(another_prompts_dataloader): 70 | batch_size = len(another_prompt_batch["word"]) 71 | batch_prompts = prompts * batch_size 72 | batch_another_prompt = another_prompt_batch["prompt"] 73 | if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking: 74 | batch_prompts.append(prompts[0]) 75 | batch_another_prompt.insert(0, prompts[0]) 76 | 77 | if do_self_or_cross_attn_inject: 78 | controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device, 79 | ldm_stable_config["low_resource"], ldm_stable_config["num_diffusion_steps"], 80 | cross_replace_steps=args.cross_attn_inject_steps, 81 | self_replace_steps=args.self_attn_inject_steps) 82 | else: 83 | controller = AttentionStore(ldm_stable_config["low_resource"]) 84 | 85 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, prompt_mixing=pm) 86 | with torch.no_grad(): 87 | image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t, other_prompt=batch_another_prompt, 88 | post_background=args.background_post_process, orig_all_latents=orig_all_latents, 89 | orig_mask=orig_mask, uncond_embeddings=uncond_embeddings) 90 | 91 | for i in range(batch_size): 92 | image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i 93 | save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg") 94 | if mask is not None: 95 | save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg") 96 | images.append(image[image_index]) 97 | 98 | images = [ToTensor()(image) for image in images] 99 | save_image(images, f"{exp_path}/grid.jpg", nrow=max([i for i in range(2, 8) if len(images) % i == 0], default=min(len(images), 8))) 100 | return images, similar_words 101 | 102 | 103 | @dataclass 104 | class LPMConfig: 105 | 106 | # general config 107 | seed: int = 10 108 | batch_size: int = 1 109 | exp_dir: str = "results" 110 | exp_name: str =
"" 111 | display_images: bool = False 112 | gpu_id: int = 0 113 | 114 | # Stable Diffusion config 115 | auth_token: str = "" 116 | low_resource: bool = True 117 | num_diffusion_steps: int = 50 118 | guidance_scale: float = 7.5 119 | max_num_words: int = 77 120 | 121 | # prompt-mixing 122 | prompt: str = "a {word} in the field eats an apple" 123 | object_of_interest: str = "snake" # The object for which we generate variations 124 | proxy_words: List[str] = field(default_factory=list) # Leave empty for automatic proxy words 125 | number_of_variations: int = 20 126 | start_prompt_range: int = 7 # Number of steps before prompt-mixing begins 127 | end_prompt_range: int = 17 # Number of steps before prompt-mixing ends 128 | 129 | # attention-based shape localization 130 | objects_to_preserve: List[str] = field(default_factory=list) # Objects to which attention-based shape localization is applied 131 | remove_obj_from_self_mask: bool = True # If set to True, removes the object of interest from the self-attention mask 132 | obj_pixels_injection_threshold: float = 0.05 133 | end_preserved_obj_self_attn_masking: int = 40 134 | 135 | # real image 136 | real_image_path: str = "" 137 | 138 | # controllable background preservation 139 | background_post_process: bool = True 140 | background_nouns: List[str] = field(default_factory=list) # Objects to take from the original image in addition to the background 141 | num_segments: int = 5 # Number of clusters for the segmentation 142 | background_segment_threshold: float = 0.3 # Threshold for labeling segments 143 | background_blend_timestep: int = 35 # Number of steps before background blending 144 | 145 | # other 146 | cross_attn_inject_steps: float = 0.0 147 | self_attn_inject_steps: float = 0.0 148 | 149 | 150 | if __name__ == '__main__': 151 | args = pyrallis.parse(config_class=LPMConfig) 152 | 153 | print(args) 154 | 155 | stable, stable_config = setup(args) 156 | main(stable, stable_config, args) 157 |
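The prompt-mixing interval set by `start_prompt_range` and `end_prompt_range` can be illustrated with a small, self-contained sketch. It assumes, following the README's description of the shape interval ($T_3$ to $T_2$), that the proxy prompt conditions the denoiser for steps in the half-open range `[start_prompt_range, end_prompt_range)` and the original prompt is used at every other step. The half-open boundaries and the names `fill_template`, `object_token_index`, and `MixingSchedule` are hypothetical and not taken from `src/prompt_mixing.py`. `object_token_index` reproduces the `object_of_interest_index` expression in `main()`; the `+ 1` presumably skips the tokenizer's start-of-text token and assumes each word maps to a single token.

```python
from dataclasses import dataclass


def fill_template(template: str, word: str) -> str:
    """Substitute a word into a prompt template such as 'a {word} in the field'."""
    return template.replace("{word}", word)


def object_token_index(template: str) -> int:
    """Same expression as object_of_interest_index in main(): word position
    of '{word}', plus 1 for the start-of-text token (one token per word assumed)."""
    return template.split().index("{word}") + 1


@dataclass
class MixingSchedule:
    # Defaults copied from LPMConfig
    start_prompt_range: int = 7
    end_prompt_range: int = 17
    num_diffusion_steps: int = 50

    def prompt_for_step(self, step: int, original: str, proxy: str) -> str:
        """Hypothetical rule: proxy prompt inside [start, end), original elsewhere."""
        if not 0 <= step < self.num_diffusion_steps:
            raise ValueError(f"step must be in [0, {self.num_diffusion_steps})")
        if self.start_prompt_range <= step < self.end_prompt_range:
            return proxy
        return original


if __name__ == "__main__":
    template = "a {word} in the field eats an apple"
    schedule = MixingSchedule()
    original = fill_template(template, "snake")
    proxy = fill_template(template, "rope")
    for step in (0, 7, 16, 17, 49):
        print(step, schedule.prompt_for_step(step, original, proxy))
    print(object_token_index(template))  # 2
```

With the default configuration, 10 of the 50 denoising steps are conditioned on the proxy prompt, so the coarse layout and the fine details are both driven by the original prompt under this assumed rule.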
-------------------------------------------------------------------------------- /real_images/lamp_simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/real_images/lamp_simple.png -------------------------------------------------------------------------------- /real_images/rinon_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/real_images/rinon_cat.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.10.2 2 | opencv-python==4.7.0.72 3 | pyrallis==0.3.1 4 | torch==1.13.1 5 | torchvision==0.14.1 6 | transformers==4.25.1 7 | nltk==3.8.1 8 | scipy 9 | scikit-learn 10 | accelerate 11 | gradio -------------------------------------------------------------------------------- /run_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from torchvision.transforms import ToTensor 6 | from torchvision.utils import save_image 7 | import matplotlib.pyplot as plt 8 | 9 | from src.attention_based_segmentation import Segmentor 10 | from src.diffusion_model_wrapper import get_stable_diffusion_model, get_stable_diffusion_config, DiffusionModelWrapper 11 | from src.null_text_inversion import invert_image 12 | from src.prompt_to_prompt_controllers import AttentionStore 13 | 14 | @dataclass 15 | class SegmentationConfig: 16 | seed: int = 1111 17 | gpu_id: int = 0 18 | real_image_path: str = "real_images/rinon_cat.jpg" 19 | auth_token: str = "" 20 | low_resource: bool = True 21 | num_diffusion_steps: int = 50 22 | guidance_scale: 
float = 7.5 23 | max_num_words: int = 77 24 | prompt: str = "a cat in a basket" 25 | exp_path: str = "segmentation_results" 26 | 27 | num_segments: int = 5 28 | background_segment_threshold: float = 0.35 29 | 30 | if __name__ == '__main__': 31 | args = SegmentationConfig() 32 | os.makedirs(args.exp_path, exist_ok=True) 33 | ldm_stable = get_stable_diffusion_model(args) 34 | ldm_stable_config = get_stable_diffusion_config(args) 35 | 36 | x_t = None 37 | uncond_embeddings = None 38 | if args.real_image_path != "": 39 | x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, [args.prompt], args.exp_path) 40 | 41 | g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed) 42 | controller = AttentionStore(ldm_stable_config["low_resource"]) 43 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu) 44 | image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward([args.prompt], 45 | latent=x_t, 46 | uncond_embeddings=uncond_embeddings) 47 | segmentor = Segmentor(controller, [args.prompt], args.num_segments, args.background_segment_threshold) 48 | clusters = segmentor.cluster() 49 | cluster2noun = segmentor.cluster2noun(clusters) 50 | 51 | save_image(ToTensor()(image[0]), f"{args.exp_path}/image_rec.jpg") 52 | plt.imshow(clusters) 53 | plt.axis('off') 54 | plt.savefig(f"{args.exp_path}/segmentation.jpg", bbox_inches='tight', pad_inches=0) 55 | -------------------------------------------------------------------------------- /src/attention_based_segmentation.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | from sklearn.cluster import KMeans 3 | import numpy as np 4 | 5 | from src.attention_utils import aggregate_attention 6 | 7 | 8 | class Segmentor: 9 | 10 | def __init__(self, controller, prompts, num_segments, background_segment_threshold, res=32, background_nouns=[]): 11 | self.controller = controller 12 | self.prompts 
= prompts 13 | self.num_segments = num_segments 14 | self.background_segment_threshold = background_segment_threshold 15 | self.resolution = res 16 | self.background_nouns = background_nouns 17 | 18 | self.self_attention = aggregate_attention(controller, res=32, from_where=("up", "down"), prompts=prompts, 19 | is_cross=False, select=len(prompts) - 1) 20 | self.cross_attention = aggregate_attention(controller, res=16, from_where=("up", "down"), prompts=prompts, 21 | is_cross=True, select=len(prompts) - 1) 22 | tokenized_prompt = nltk.word_tokenize(prompts[-1]) 23 | self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt)) if pos[:2] == 'NN'] 24 | 25 | def __call__(self, *args, **kwargs): 26 | clusters = self.cluster() 27 | cluster2noun = self.cluster2noun(clusters) 28 | return cluster2noun 29 | 30 | def cluster(self): 31 | np.random.seed(1) 32 | resolution = self.self_attention.shape[0] 33 | attn = self.self_attention.cpu().numpy().reshape(resolution ** 2, resolution ** 2) 34 | kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(attn) 35 | clusters = kmeans.labels_ 36 | clusters = clusters.reshape(resolution, resolution) 37 | return clusters 38 | 39 | def cluster2noun(self, clusters): 40 | result = {} 41 | nouns_indices = [index for (index, word) in self.nouns] 42 | nouns_maps = self.cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]] 43 | normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1) 44 | for i in range(nouns_maps.shape[-1]): 45 | curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1) 46 | normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max() 47 | for c in range(self.num_segments): 48 | cluster_mask = np.zeros_like(clusters) 49 | cluster_mask[clusters == c] = 1 50 | score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))] 51 | scores = [score_map.sum() / 
cluster_mask.sum() for score_map in score_maps] 52 | result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > self.background_segment_threshold else "BG" 53 | return result 54 | 55 | def get_background_mask(self, obj_token_index): 56 | clusters = self.cluster() 57 | cluster2noun = self.cluster2noun(clusters) 58 | mask = clusters.copy() 59 | obj_segments = [c for c in cluster2noun if cluster2noun[c][0] == obj_token_index - 1] 60 | background_segments = [c for c in cluster2noun if cluster2noun[c] == "BG" or cluster2noun[c][1] in self.background_nouns] 61 | for c in range(self.num_segments): 62 | if c in background_segments and c not in obj_segments: 63 | mask[clusters == c] = 0 64 | else: 65 | mask[clusters == c] = 1 66 | return mask 67 | 68 | -------------------------------------------------------------------------------- /src/attention_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Tuple, List 4 | from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | 8 | from src.prompt_to_prompt_controllers import AttentionStore 9 | 10 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts): 11 | out = [] 12 | attention_maps = attention_store.get_average_attention() 13 | num_pixels = res ** 2 14 | for location in from_where: 15 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 16 | if item.shape[1] == num_pixels: 17 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 18 | out.append(cross_maps) 19 | out = torch.cat(out, dim=0) 20 | out = out.sum(0) / out.shape[0] 21 | return out.cpu() 22 | 23 | 24 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], prompts, tokenizer, select: int = 0): 25 | tokens = 
tokenizer.encode(prompts[select]) 26 | decoder = tokenizer.decode 27 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts) 28 | images = [] 29 | for i in range(len(tokens)): 30 | image = attention_maps[:, :, i] 31 | image = 255 * image / image.max() 32 | image = image.unsqueeze(-1).expand(*image.shape, 3) 33 | image = image.numpy().astype(np.uint8) 34 | image = np.array(Image.fromarray(image).resize((256, 256))) 35 | image = text_under_image(image, decoder(int(tokens[i]))) 36 | images.append(image) 37 | view_images(np.stack(images, axis=0)) 38 | 39 | 40 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], prompts, 41 | max_com=10, select: int = 0): 42 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape( 43 | (res ** 2, res ** 2)) 44 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 45 | images = [] 46 | for i in range(max_com): 47 | image = vh[i].reshape(res, res) 48 | image = image - image.min() 49 | image = 255 * image / image.max() 50 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 51 | image = Image.fromarray(image).resize((256, 256)) 52 | image = np.array(image) 53 | images.append(image) 54 | view_images(np.concatenate(images, axis=1)) 55 | 56 | 57 | def view_images(images, num_rows=1, offset_ratio=0.02): 58 | if type(images) is list: 59 | num_empty = len(images) % num_rows 60 | elif images.ndim == 4: 61 | num_empty = images.shape[0] % num_rows 62 | else: 63 | images = [images] 64 | num_empty = 0 65 | 66 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 67 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 68 | num_items = len(images) 69 | 70 | h, w, c = images[0].shape 71 | offset = int(h * offset_ratio) 72 | num_cols = num_items // num_rows 73 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 74 | w *
num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 75 | for i in range(num_rows): 76 | for j in range(num_cols): 77 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 78 | i * num_cols + j] 79 | 80 | pil_img = Image.fromarray(image_) 81 | display(pil_img) 82 | 83 | 84 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 85 | h, w, c = image.shape 86 | offset = int(h * .2) 87 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 88 | font = FONT_HERSHEY_SIMPLEX 89 | img[:h] = image 90 | textsize = getTextSize(text, font, 1, 2)[0] 91 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 92 | putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 93 | return img 94 | 95 | 96 | def display(image): 97 | global display_index 98 | plt.imshow(image) 99 | plt.show() 100 | -------------------------------------------------------------------------------- /src/diffusion_model_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | from cv2 import dilate 6 | from diffusers import DDIMScheduler, StableDiffusionPipeline 7 | from tqdm import tqdm 8 | 9 | from src.attention_based_segmentation import Segmentor 10 | from src.attention_utils import show_cross_attention 11 | from src.prompt_to_prompt_controllers import DummyController, AttentionStore 12 | 13 | 14 | def get_stable_diffusion_model(args): 15 | device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu') 16 | if args.real_image_path != "": 17 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) 18 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token, scheduler=scheduler).to(device) 
19 | else: 20 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token).to(device) 21 | 22 | return ldm_stable 23 | 24 | def get_stable_diffusion_config(args): 25 | return { 26 | "low_resource": args.low_resource, 27 | "num_diffusion_steps": args.num_diffusion_steps, 28 | "guidance_scale": args.guidance_scale, 29 | "max_num_words": args.max_num_words 30 | } 31 | 32 | 33 | def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings): 34 | g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed) 35 | controller = AttentionStore(ldm_stable_config["low_resource"]) 36 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu) 37 | image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts, 38 | latent=latent, 39 | uncond_embeddings=uncond_embeddings) 40 | orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold, background_nouns=args.background_nouns)\ 41 | .get_background_mask(args.prompt.split(' ').index("{word}") + 1) 42 | average_attention = controller.get_average_attention() 43 | return image, x_t, orig_all_latents, orig_mask, average_attention 44 | 45 | 46 | class DiffusionModelWrapper: 47 | def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None): 48 | self.args = args 49 | self.model = model 50 | self.model_config = model_config 51 | self.controller = controller 52 | if self.controller is None: 53 | self.controller = DummyController() 54 | self.prompt_mixing = prompt_mixing 55 | self.device = model.device 56 | self.generator = generator 57 | 58 | self.height = 512 59 | self.width = 512 60 | 61 | self.diff_step = 0 62 | self.register_attention_control() 63 | 64 | 65 | def diffusion_step(self, latents, context, t, other_context=None): 66 | if self.model_config["low_resource"]: 67 | self.uncond_pred 
= True 68 | noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"] 69 | self.uncond_pred = False 70 | noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"] 71 | else: 72 | latents_input = torch.cat([latents] * 2) 73 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"] 74 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 75 | noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond) 76 | latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"] 77 | latents = self.controller.step_callback(latents) 78 | return latents 79 | 80 | 81 | def latent2image(self, latents): 82 | latents = 1 / 0.18215 * latents 83 | image = self.model.vae.decode(latents)['sample'] 84 | image = (image / 2 + 0.5).clamp(0, 1) 85 | image = image.cpu().permute(0, 2, 3, 1).numpy() 86 | image = (image * 255).astype(np.uint8) 87 | return image 88 | 89 | 90 | def init_latent(self, latent, batch_size): 91 | if latent is None: 92 | latent = torch.randn( 93 | (1, self.model.unet.in_channels, self.height // 8, self.width // 8), 94 | generator=self.generator, device=self.model.device 95 | ) 96 | latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device) 97 | return latent, latents 98 | 99 | 100 | def register_attention_control(self): 101 | def ca_forward(model_self, place_in_unet): 102 | to_out = model_self.to_out 103 | if type(to_out) is torch.nn.modules.container.ModuleList: 104 | to_out = model_self.to_out[0] 105 | else: 106 | to_out = model_self.to_out 107 | 108 | def forward(x, context=None, mask=None): 109 | batch_size, sequence_length, dim = x.shape 110 | h = model_self.heads 111 | q = model_self.to_q(x) 112 | is_cross = context is not None 113 | context = context if is_cross else (x, None) 114 | 115 | 
k = model_self.to_k(context[0]) 116 | if is_cross and self.prompt_mixing is not None: 117 | v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1]) 118 | v = model_self.to_v(v_context) 119 | else: 120 | v = model_self.to_v(context[0]) 121 | 122 | q = model_self.reshape_heads_to_batch_dim(q) 123 | k = model_self.reshape_heads_to_batch_dim(k) 124 | v = model_self.reshape_heads_to_batch_dim(v) 125 | 126 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale 127 | 128 | if mask is not None: 129 | mask = mask.reshape(batch_size, -1) 130 | max_neg_value = -torch.finfo(sim.dtype).max 131 | mask = mask[:, None, :].repeat(h, 1, 1) 132 | sim.masked_fill_(~mask, max_neg_value) 133 | 134 | # attention, what we cannot get enough of 135 | attn = sim.softmax(dim=-1) 136 | if self.enbale_attn_controller_changes: 137 | attn = self.controller(attn, is_cross, place_in_unet) 138 | 139 | if is_cross and self.prompt_mixing is not None and context[1] is not None: 140 | attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size) 141 | 142 | if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None: 143 | attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size) 144 | 145 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 146 | out = model_self.reshape_batch_dim_to_heads(out) 147 | return to_out(out) 148 | 149 | return forward 150 | 151 | def register_recr(net_, count, place_in_unet): 152 | if net_.__class__.__name__ == 'CrossAttention': 153 | net_.forward = ca_forward(net_, place_in_unet) 154 | return count + 1 155 | elif hasattr(net_, 'children'): 156 | for net__ in net_.children(): 157 | count = register_recr(net__, count, place_in_unet) 158 | return count 159 | 160 | cross_att_count = 0 161 | sub_nets = self.model.unet.named_children() 162 | for net in sub_nets: 163 | if "down" in net[0]: 164 | 
cross_att_count += register_recr(net[1], 0, "down") 165 | elif "up" in net[0]: 166 | cross_att_count += register_recr(net[1], 0, "up") 167 | elif "mid" in net[0]: 168 | cross_att_count += register_recr(net[1], 0, "mid") 169 | self.controller.num_att_layers = cross_att_count 170 | 171 | 172 | def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True): 173 | text_input = self.model.tokenizer( 174 | prompt, 175 | padding="max_length", 176 | max_length=self.model.tokenizer.model_max_length if max_length is None else max_length, 177 | truncation=truncation, 178 | return_tensors="pt", 179 | ) 180 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0] 181 | max_length = text_input.input_ids.shape[-1] 182 | return text_embeddings, max_length 183 | 184 | 185 | @torch.no_grad() 186 | def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None, 187 | other_prompt: List[str] = None, post_background = False, orig_all_latents = None, orig_mask = None, 188 | uncond_embeddings=None, start_time=51, return_type='image'): 189 | self.enbale_attn_controller_changes = True 190 | batch_size = len(prompt) 191 | 192 | text_embeddings, max_length = self.get_text_embedding(prompt) 193 | if uncond_embeddings is None: 194 | uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False) 195 | else: 196 | uncond_embeddings_ = None 197 | 198 | other_context = None 199 | if other_prompt is not None: 200 | other_text_embeddings, _ = self.get_text_embedding(other_prompt) 201 | other_context = other_text_embeddings 202 | 203 | latent, latents = self.init_latent(latent, batch_size) 204 | 205 | # set timesteps 206 | self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"]) 207 | all_latents = [] 208 | 209 | object_mask = None 210 | self.diff_step = 0 211 | for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])): 212 | if uncond_embeddings_ is None: 
213 | context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings] 214 | else: 215 | context = [uncond_embeddings_, text_embeddings] 216 | if not self.model_config["low_resource"]: 217 | context = torch.cat(context) 218 | 219 | self.down_cross_index = 0 220 | self.mid_cross_index = 0 221 | self.up_cross_index = 0 222 | latents = self.diffusion_step(latents, context, t, other_context) 223 | 224 | if post_background and self.diff_step == self.args.background_blend_timestep: 225 | object_mask = Segmentor(self.controller, 226 | prompt, 227 | self.args.num_segments, 228 | self.args.background_segment_threshold, 229 | background_nouns=self.args.background_nouns)\ 230 | .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1) 231 | self.enbale_attn_controller_changes = False 232 | mask = object_mask.astype(bool) | orig_mask.astype(bool) 233 | mask = torch.from_numpy(mask).float().to(self.device) 234 | shape = (1, 1, mask.shape[0], mask.shape[1]) 235 | mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape)) 236 | mask_dilated = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1) 237 | mask = torch.from_numpy(mask_dilated).float().to(self.device).view(1, 1, 64, 64) 238 | latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step] 239 | 240 | all_latents.append(latents) 241 | self.diff_step += 1 242 | 243 | if return_type == 'image': 244 | image = self.latent2image(latents) 245 | else: 246 | image = latents 247 | 248 | return image, latent, all_latents, object_mask 249 | 250 | 251 | def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0): 252 | show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select) -------------------------------------------------------------------------------- /src/null_text_inversion.py: -------------------------------------------------------------------------------- 1 | from typing import
 Union 2 | 3 | from torchvision.transforms import ToTensor 4 | from torchvision.utils import save_image 5 | from tqdm import tqdm 6 | import torch 7 | from torch.optim.adam import Adam 8 | import torch.nn.functional as nnf 9 | import numpy as np 10 | from PIL import Image 11 | 12 | 13 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 14 | if type(image_path) is str: 15 | image = np.array(Image.open(image_path).convert("RGB")) 16 | else: 17 | image = image_path 18 | h, w, c = image.shape 19 | left = min(left, w-1) 20 | right = min(right, w - left - 1) 21 | top = min(top, h - 1) 22 | bottom = min(bottom, h - top - 1) 23 | image = image[top:h-bottom, left:w-right] 24 | h, w, c = image.shape 25 | if h < w: 26 | offset = (w - h) // 2 27 | image = image[:, offset:offset + h] 28 | elif w < h: 29 | offset = (h - w) // 2 30 | image = image[offset:offset + w] 31 | image = np.array(Image.fromarray(image).resize((512, 512))) 32 | return image 33 | 34 | 35 | def invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path): 36 | print("Start null text inversion") 37 | null_inversion = NullInversion(ldm_stable, ldm_stable_config) 38 | (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(args.real_image_path, prompts[0], offsets=(0,0,0,0), verbose=True) 39 | save_image(ToTensor()(image_gt), f"{exp_path}/real_image.jpg") 40 | save_image(ToTensor()(image_enc), f"{exp_path}/image_enc.jpg") 41 | print("End null text inversion") 42 | return x_t, uncond_embeddings 43 | 44 | 45 | class NullInversion: 46 | 47 | def __init__(self, model, model_config): 48 | self.model = model 49 | self.model_config = model_config 50 | self.tokenizer = self.model.tokenizer 51 | self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"]) 52 | self.prompt = None 53 | self.context = None 54 | 55 | 56 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 57 | prev_timestep = 
timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 58 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 59 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 60 | beta_prod_t = 1 - alpha_prod_t 61 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 62 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output 63 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction 64 | return prev_sample 65 | 66 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 67 | timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep 68 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 69 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 70 | beta_prod_t = 1 - alpha_prod_t 71 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 72 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 73 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 74 | return next_sample 75 | 76 | def get_noise_pred_single(self, latents, t, context): 77 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"] 78 | return noise_pred 79 | 80 | def get_noise_pred(self, latents, t, is_forward=True, context=None): 81 | latents_input = torch.cat([latents] * 2) 82 | if context is None: 83 | context = self.context 84 | guidance_scale = 1 if is_forward else self.model_config["guidance_scale"] 85 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 86 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 87 | 
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 88 | if is_forward: 89 | latents = self.next_step(noise_pred, t, latents) 90 | else: 91 | latents = self.prev_step(noise_pred, t, latents) 92 | return latents 93 | 94 | @torch.no_grad() 95 | def latent2image(self, latents, return_type='np'): 96 | latents = 1 / 0.18215 * latents.detach() 97 | image = self.model.vae.decode(latents)['sample'] 98 | if return_type == 'np': 99 | image = (image / 2 + 0.5).clamp(0, 1) 100 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 101 | image = (image * 255).astype(np.uint8) 102 | return image 103 | 104 | @torch.no_grad() 105 | def image2latent(self, image): 106 | with torch.no_grad(): 107 | if isinstance(image, Image.Image): 108 | image = np.array(image) 109 | if isinstance(image, torch.Tensor) and image.dim() == 4: 110 | latents = image 111 | else: 112 | image = torch.from_numpy(image).float() / 127.5 - 1 113 | image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device) 114 | latents = self.model.vae.encode(image)['latent_dist'].mean 115 | latents = latents * 0.18215 116 | return latents 117 | 118 | @torch.no_grad() 119 | def init_prompt(self, prompt: str): 120 | uncond_input = self.model.tokenizer( 121 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, 122 | return_tensors="pt" 123 | ) 124 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 125 | text_input = self.model.tokenizer( 126 | [prompt], 127 | padding="max_length", 128 | max_length=self.model.tokenizer.model_max_length, 129 | truncation=True, 130 | return_tensors="pt", 131 | ) 132 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 133 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 134 | self.prompt = prompt 135 | 136 | @torch.no_grad() 137 | def ddim_loop(self, latent): 138 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 139 | all_latent = 
[latent] 140 | latent = latent.clone().detach() 141 | for i in tqdm(range(self.model_config["num_diffusion_steps"])): 142 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] 143 | noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) 144 | latent = self.next_step(noise_pred, t, latent) 145 | all_latent.append(latent) 146 | return all_latent 147 | 148 | @property 149 | def scheduler(self): 150 | return self.model.scheduler 151 | 152 | @torch.no_grad() 153 | def ddim_inversion(self, image): 154 | latent = self.image2latent(image) 155 | image_rec = self.latent2image(latent) 156 | ddim_latents = self.ddim_loop(latent) 157 | return image_rec, ddim_latents 158 | 159 | def null_optimization(self, latents, num_inner_steps, epsilon): 160 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 161 | uncond_embeddings_list = [] 162 | latent_cur = latents[-1] 163 | with tqdm(total=num_inner_steps * (self.model_config["num_diffusion_steps"])) as bar: 164 | for i in range(self.model_config["num_diffusion_steps"]): 165 | uncond_embeddings = uncond_embeddings.clone().detach() 166 | uncond_embeddings.requires_grad = True 167 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. 
- i / 100.)) 168 | latent_prev = latents[len(latents) - i - 2] 169 | t = self.model.scheduler.timesteps[i] 170 | with torch.no_grad(): 171 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) 172 | for j in range(num_inner_steps): 173 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) 174 | noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond) 175 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) 176 | loss = nnf.mse_loss(latents_prev_rec, latent_prev) 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | loss_item = loss.item() 181 | bar.update() 182 | if loss_item < epsilon + i * 2e-5: 183 | break 184 | bar.update(num_inner_steps - j - 1) 185 | uncond_embeddings_list.append(uncond_embeddings[:1].detach()) 186 | with torch.no_grad(): 187 | context = torch.cat([uncond_embeddings, cond_embeddings]) 188 | latent_cur = self.get_noise_pred(latent_cur, t, False, context) 189 | # bar.close() 190 | return uncond_embeddings_list 191 | 192 | def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False): 193 | self.init_prompt(prompt) 194 | image_gt = load_512(image_path, *offsets) 195 | if verbose: 196 | print("DDIM inversion...") 197 | image_rec, ddim_latents = self.ddim_inversion(image_gt) 198 | if verbose: 199 | print("Null-text optimization...") 200 | uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon) 201 | return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings 202 | -------------------------------------------------------------------------------- /src/prompt_mixing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.signal import medfilt2d 3 | 4 | class PromptMixing: 5 | def __init__(self, args, object_of_interest_index, avg_cross_attn=None): 
6 | self.object_of_interest_index = object_of_interest_index 7 | self.objects_to_preserve = [args.prompt.split().index(o) + 1 for o in args.objects_to_preserve] 8 | self.obj_pixels_injection_threshold = args.obj_pixels_injection_threshold 9 | 10 | self.start_other_prompt_range = args.start_prompt_range 11 | self.end_other_prompt_range = args.end_prompt_range 12 | 13 | self.start_cross_attn_replace_range = args.num_diffusion_steps 14 | self.end_cross_attn_replace_range = args.num_diffusion_steps 15 | 16 | self.start_self_attn_replace_range = 0 17 | self.end_self_attn_replace_range = args.end_preserved_obj_self_attn_masking 18 | self.remove_obj_from_self_mask = args.remove_obj_from_self_mask 19 | self.avg_cross_attn = avg_cross_attn 20 | 21 | self.low_resource = args.low_resource 22 | 23 | def get_context_for_v(self, t, context, other_context): 24 | if other_context is not None and \ 25 | self.start_other_prompt_range <= t < self.end_other_prompt_range: 26 | if self.low_resource: 27 | return other_context 28 | else: 29 | v_context = context.clone() 30 | # first half of context is for the unconditioned image 31 | v_context[v_context.shape[0]//2:] = other_context 32 | return v_context 33 | else: 34 | return context 35 | 36 | def get_cross_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size): 37 | if self.start_cross_attn_replace_range <= t < self.end_cross_attn_replace_range: 38 | if self.low_resource: 39 | attn[:,:,self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \ 40 | 0.8 * attn[:, :, self.object_of_interest_index] 41 | else: 42 | # first half of attn maps is for the unconditioned image 43 | min_h = attn.shape[0] // 2 44 | attn[min_h:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[min_h:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \ 45 | 0.8 * attn[min_h:, :,
self.object_of_interest_index] 46 | return attn 47 | 48 | def get_self_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size): 49 | if attn.shape[1] <= 32 ** 2 and \ 50 | self.avg_cross_attn is not None and \ 51 | self.start_self_attn_replace_range <= t < self.end_self_attn_replace_range: 52 | 53 | key = f"{place_in_unet}_cross" 54 | attn_index = getattr(diffusion_model_wrapper, f'{key}_index') 55 | cr = self.avg_cross_attn[key][attn_index] 56 | setattr(diffusion_model_wrapper, f'{key}_index', attn_index+1) 57 | 58 | if self.low_resource: 59 | attn = self.mask_self_attn_patches(attn, cr, batch_size) 60 | else: 61 | # first half of attn maps is for the unconditioned image 62 | attn[attn.shape[0]//2:] = self.mask_self_attn_patches(attn[attn.shape[0]//2:], cr, batch_size//2) 63 | 64 | return attn 65 | 66 | def mask_self_attn_patches(self, self_attn, cross_attn, batch_size): 67 | h = self_attn.shape[0] // batch_size 68 | tokens = self.objects_to_preserve 69 | obj_token = self.object_of_interest_index 70 | 71 | normalized_cross_attn = cross_attn - cross_attn.min() 72 | normalized_cross_attn /= normalized_cross_attn.max() 73 | 74 | mask = torch.zeros_like(self_attn[0]) 75 | for tk in tokens: 76 | mask_tk_in = torch.unique((normalized_cross_attn[:,:,tk] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1]) 77 | mask[mask_tk_in, :] = 1 78 | mask[:, mask_tk_in] = 1 79 | 80 | if self.remove_obj_from_self_mask: 81 | obj_patches = torch.unique((normalized_cross_attn[:,:,obj_token] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1]) 82 | mask[obj_patches, :] = 0 83 | mask[:, obj_patches] = 0 84 | 85 | self_attn[h:] = self_attn[h:] * (1 - mask) + self_attn[:h].repeat(batch_size - 1, 1, 1) * mask 86 | return self_attn -------------------------------------------------------------------------------- /src/prompt_to_prompt_controllers.py: -------------------------------------------------------------------------------- 1 | import torch 2 |
import numpy as np 3 | import abc 4 | from typing import Optional, Union, Tuple, Dict 5 | import src.seq_aligner as seq_aligner 6 | 7 | 8 | class AttentionControl(abc.ABC): 9 | 10 | def step_callback(self, x_t): 11 | return x_t 12 | 13 | def between_steps(self): 14 | return 15 | 16 | @property 17 | def num_uncond_att_layers(self): 18 | return self.num_att_layers if self.low_resource else 0 19 | 20 | @abc.abstractmethod 21 | def forward(self, attn, is_cross: bool, place_in_unet: str): 22 | raise NotImplementedError 23 | 24 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 25 | if self.cur_att_layer >= self.num_uncond_att_layers: 26 | if self.low_resource: 27 | attn = self.forward(attn, is_cross, place_in_unet) 28 | else: 29 | h = attn.shape[0] 30 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 31 | self.cur_att_layer += 1 32 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 33 | self.cur_att_layer = 0 34 | self.cur_step += 1 35 | self.between_steps() 36 | return attn 37 | 38 | def reset(self): 39 | self.cur_step = 0 40 | self.cur_att_layer = 0 41 | 42 | def __init__(self, low_resource): 43 | self.cur_step = 0 44 | self.num_att_layers = -1 45 | self.cur_att_layer = 0 46 | self.low_resource = low_resource 47 | 48 | 49 | class EmptyControl(AttentionControl): 50 | 51 | def forward(self, attn, is_cross: bool, place_in_unet: str): 52 | return attn 53 | 54 | 55 | class DummyController: 56 | def __call__(self, *args): 57 | return args[0] 58 | 59 | def __init__(self): 60 | self.num_att_layers = 0 61 | 62 | 63 | class AttentionStore(AttentionControl): 64 | 65 | @staticmethod 66 | def get_empty_store(): 67 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 68 | "down_self": [], "mid_self": [], "up_self": []} 69 | 70 | def forward(self, attn, is_cross: bool, place_in_unet: str): 71 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 72 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 73 | 
self.step_store[key].append(attn) 74 | return attn 75 | 76 | def between_steps(self): 77 | if len(self.attention_store) == 0: 78 | self.attention_store = self.step_store 79 | else: 80 | for key in self.attention_store: 81 | for i in range(len(self.attention_store[key])): 82 | self.attention_store[key][i] += self.step_store[key][i] 83 | self.step_store = self.get_empty_store() 84 | 85 | def get_average_attention(self): 86 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 87 | self.attention_store} 88 | return average_attention 89 | 90 | def reset(self): 91 | super(AttentionStore, self).reset() 92 | self.step_store = self.get_empty_store() 93 | self.attention_store = {} 94 | 95 | def __init__(self, low_resource): 96 | super(AttentionStore, self).__init__(low_resource) 97 | self.step_store = self.get_empty_store() 98 | self.attention_store = {} 99 | 100 | 101 | class AttentionControlEdit(AttentionStore, abc.ABC): 102 | 103 | def step_callback(self, x_t): 104 | return x_t 105 | 106 | def replace_self_attention(self, attn_base, att_replace): 107 | if att_replace.shape[2] <= 16 ** 2: 108 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 109 | else: 110 | return att_replace 111 | 112 | @abc.abstractmethod 113 | def replace_cross_attention(self, attn_base, att_replace): 114 | raise NotImplementedError 115 | 116 | def forward(self, attn, is_cross: bool, place_in_unet: str): 117 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 118 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 119 | h = attn.shape[0] // (self.batch_size) 120 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 121 | attn_base, attn_repalce = attn[0], attn[1:] 122 | if is_cross: 123 | alpha_words = self.cross_replace_alpha[self.cur_step] 124 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + ( 125 | 1 - alpha_words) * 
attn_repalce 126 | attn[1:] = attn_repalce_new 127 | else: 128 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 129 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 130 | return attn 131 | 132 | def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int, 133 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 134 | self_replace_steps: Union[float, Tuple[float, float]]): 135 | super(AttentionControlEdit, self).__init__(low_resource) 136 | self.batch_size = len(prompts) 137 | self.tokenizer = tokenizer 138 | self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, 139 | self.tokenizer).to(device) 140 | if type(self_replace_steps) is float: 141 | self_replace_steps = 0, self_replace_steps 142 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 143 | 144 | 145 | class AttentionReplace(AttentionControlEdit): 146 | 147 | def replace_cross_attention(self, attn_base, att_replace): 148 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper.to(attn_base.dtype)) 149 | 150 | def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int, cross_replace_steps: float, self_replace_steps: float): 151 | super(AttentionReplace, self).__init__(prompts, tokenizer, device, low_resource, num_steps, cross_replace_steps, self_replace_steps) 152 | self.mapper = seq_aligner.get_replacement_mapper(prompts, self.tokenizer).to(device) 153 | 154 | 155 | def get_word_inds(text: str, word_place: int, tokenizer): 156 | split_text = text.split(" ") 157 | if type(word_place) is str: 158 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 159 | elif type(word_place) is int: 160 | word_place = [word_place] 161 | out = [] 162 | if len(word_place) > 0: 163 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 164 | cur_len, ptr = 0, 0 165 | 
166 | for i in range(len(words_encode)): 167 | cur_len += len(words_encode[i]) 168 | if ptr in word_place: 169 | out.append(i + 1) 170 | if cur_len >= len(split_text[ptr]): 171 | ptr += 1 172 | cur_len = 0 173 | return np.array(out) 174 | 175 | 176 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None): 177 | if type(bounds) is float: 178 | bounds = 0, bounds 179 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 180 | if word_inds is None: 181 | word_inds = torch.arange(alpha.shape[2]) 182 | alpha[: start, prompt_ind, word_inds] = 0 183 | alpha[start: end, prompt_ind, word_inds] = 1 184 | alpha[end:, prompt_ind, word_inds] = 0 185 | return alpha 186 | 187 | 188 | def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 189 | tokenizer, max_num_words=77): 190 | if type(cross_replace_steps) is not dict: 191 | cross_replace_steps = {"default_": cross_replace_steps} 192 | if "default_" not in cross_replace_steps: 193 | cross_replace_steps["default_"] = (0., 1.) 
194 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 195 | for i in range(len(prompts) - 1): 196 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 197 | i) 198 | for key, item in cross_replace_steps.items(): 199 | if key != "default_": 200 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 201 | for i, ind in enumerate(inds): 202 | if len(ind) > 0: 203 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 204 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words 205 | return alpha_time_words -------------------------------------------------------------------------------- /src/prompt_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | 7 | def get_topk_similar_words(model, prompt, base_word, vocab, k=30): 8 | text_input = model.tokenizer( 9 | [prompt.format(word=base_word)], 10 | padding="max_length", 11 | max_length=model.tokenizer.model_max_length, 12 | truncation=True, 13 | return_tensors="pt", 14 | ) 15 | with torch.no_grad(): 16 | encoder_output = model.text_encoder(text_input.input_ids.to(model.device)) 17 | full_prompt_embedding = encoder_output.pooler_output 18 | full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True) 19 | 20 | prompts = [prompt.format(word=word) for word in vocab] 21 | batch_size = 1000 22 | all_prompts_embeddings = [] 23 | for i in tqdm(range(0, len(prompts), batch_size)): 24 | curr_prompts = prompts[i:i + batch_size] 25 | with torch.no_grad(): 26 | text_input = model.tokenizer( 27 | curr_prompts, 28 | padding="max_length", 29 | max_length=model.tokenizer.model_max_length, 30 | truncation=True, 31 | return_tensors="pt", 32 | ) 33 | curr_embeddings = 
model.text_encoder(text_input.input_ids.to(model.device)).pooler_output 34 | all_prompts_embeddings.append(curr_embeddings) 35 | 36 | all_prompts_embeddings = torch.cat(all_prompts_embeddings) 37 | all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True) 38 | prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1)) 39 | sorted_prompts_similarities = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort()) 40 | 41 | print(f"prompt: {prompt}") 42 | print(f"initial word: {base_word}") 43 | print(f"TOP {k} SIMILAR WORDS:") 44 | similar_words = [vocab[index] for index in sorted_prompts_similarities[:k]] 45 | print(similar_words) 46 | return similar_words 47 | 48 | def get_proxy_words(args, ldm_stable): 49 | if len(args.proxy_words) > 0: 50 | return [args.object_of_interest] + args.proxy_words 51 | vocab = list(json.load(open("vocab.json")).keys()) 52 | vocab = [word for word in vocab if word.isalpha() and len(word) > 1] 53 | filtered_vocab = get_topk_similar_words(ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50) 54 | proxy_words = get_topk_similar_words(ldm_stable, args.prompt, args.object_of_interest, filtered_vocab, k=args.number_of_variations) 55 | if proxy_words[0] != args.object_of_interest: 56 | proxy_words = [args.object_of_interest] + proxy_words 57 | 58 | return proxy_words 59 | 60 | def get_proxy_prompts(args, ldm_stable): 61 | proxy_words = get_proxy_words(args, ldm_stable) 62 | prompts = [args.prompt.format(word=args.object_of_interest)] 63 | proxy_prompts = [{"word": word, "prompt": args.prompt.format(word=word)} for word in proxy_words] 64 | return proxy_words, prompts, proxy_prompts -------------------------------------------------------------------------------- /src/seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, 
Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j*gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i*gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | 
diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i-1]) 88 | y_seq.append(y[j-1]) 89 | i = i-1 90 | j = j-1 91 | mapper_y_to_x.append((j, i)) 92 | elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j-1]) 95 | j = j-1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i-1]) 99 | y_seq.append('-') 100 | i = i-1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | 
def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 | j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 
| j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 189 | x_seq = prompts[0] 190 | mappers = [] 191 | for i in range(1, len(prompts)): 192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 193 | mappers.append(mapper) 194 | return torch.stack(mappers) 195 | 196 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | text-align: center; 3 | } 4 | --------------------------------------------------------------------------------
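`PromptMixing.mask_self_attn_patches` in `src/prompt_mixing.py` decides which self-attention entries of each variation are overwritten with those of the first image in the batch. It min-max normalizes the cross-attention tensor. Every pixel whose normalized score for a preserved token exceeds `obj_pixels_injection_threshold` (in any head) marks its whole row and column in the mask. When `remove_obj_from_self_mask` is set, the rows and columns of the object-of-interest pixels are then cleared again. Finally, the result is blended as `variation * (1 - mask) + base * mask`. The sketch below restates that logic with plain Python lists so it can run in isolation. The `[head][pixel][token]` nested-list layout and the helper names are assumptions made for illustration, not the repository's tensor API.

```python
def thresholded_pixels(cross_attn, token, threshold):
    """Pixels whose min-max normalized attention to `token` exceeds `threshold` in any head.

    cross_attn is a hypothetical nested list indexed as [head][pixel][token].
    """
    values = [v for head in cross_attn for row in head for v in row]
    lo, hi = min(values), max(values)
    span = (hi - lo) or 1.0  # guard against a constant map (the repository code has no such guard)
    return {p for head in cross_attn for p, row in enumerate(head)
            if (row[token] - lo) / span > threshold}


def self_attn_mask(cross_attn, preserve_tokens, threshold, obj_token=None):
    """Pixel-by-pixel mask: 1 where the base image's self-attention should be injected."""
    n_pix = len(cross_attn[0])
    mask = [[0.0] * n_pix for _ in range(n_pix)]
    # A preserved pixel marks its full row (as a query) and full column (as a key).
    for tk in preserve_tokens:
        for p in thresholded_pixels(cross_attn, tk, threshold):
            for q in range(n_pix):
                mask[p][q] = mask[q][p] = 1.0
    # Counterpart of remove_obj_from_self_mask: clear the object's rows and columns afterwards.
    if obj_token is not None:
        for p in thresholded_pixels(cross_attn, obj_token, threshold):
            for q in range(n_pix):
                mask[p][q] = mask[q][p] = 0.0
    return mask


def inject(base_attn, variation_attn, mask):
    """Element-wise variation * (1 - mask) + base * mask."""
    return [[v * (1 - m) + b * m for b, v, m in zip(rb, rv, rm)]
            for rb, rv, rm in zip(base_attn, variation_attn, mask)]
```

The clearing step runs after the preserved rows and columns are set. As a result, an entry that links a preserved pixel to an object pixel ends up at 0, which appears intended to leave the object region unconstrained.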
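`get_word_inds` is defined in both `src/prompt_to_prompt_controllers.py` and `src/seq_aligner.py`. It maps a word position in a space-separated prompt to the token positions that word occupies. It does this by accumulating the character lengths of the decoded sub-word tokens until they cover the word, and it adds 1 to each position to skip the start-of-text token. Below is a tokenizer-free sketch of that loop. `pieces` stands in for the decoded tokens without BOS/EOS, and the function name is hypothetical.

```python
def word_to_token_indices(words, pieces, word_place):
    """Token positions (offset by 1 for a BOS token) covering words[word_place].

    words:  the prompt split on spaces.
    pieces: decoded sub-word tokens without BOS/EOS (a stand-in for tokenizer.decode output).
    """
    out, cur_len, ptr = [], 0, 0
    for i, piece in enumerate(pieces):
        cur_len += len(piece)
        if ptr == word_place:
            out.append(i + 1)
        # Move to the next word once its characters are fully covered.
        if cur_len >= len(words[ptr]):
            ptr += 1
            cur_len = 0
    return out


words = ["a", "photo", "of", "a", "hedgehog"]
pieces = ["a", "photo", "of", "a", "hedge", "hog"]
print(word_to_token_indices(words, pieces, 4))  # [5, 6]
```

By contrast, `PromptMixing.__init__` computes `objects_to_preserve` as `args.prompt.split().index(o) + 1`. That value matches the token position only when every word before `o` is encoded as a single token.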
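In `src/prompt_to_prompt_controllers.py`, `get_time_words_attention_alpha` builds a `(num_steps + 1, len(prompts) - 1, 1, 1, 77)` table through `update_alpha_time_word`. That function sets an entry to 1 for the rows that fall inside a fractional window and to 0 elsewhere, and it reads a bare float `b` as the window `(0, b)`. At each step, `AttentionControlEdit.forward` then blends the mapped base-prompt cross-attention with the edited prompt's own cross-attention as `mapped * alpha + (1 - alpha) * own`. The following sketch computes a single prompt/word column of that table in plain Python. The function name is hypothetical.

```python
def time_window_alpha(num_steps, bounds):
    """One column of the alpha table: 1.0 inside the fractional step window, else 0.0.

    Mirrors update_alpha_time_word: the table has num_steps + 1 rows, and the
    window [int(b0 * rows), int(b1 * rows)) is truncated toward zero.
    """
    if isinstance(bounds, float):
        bounds = (0.0, bounds)
    rows = num_steps + 1
    start, end = int(bounds[0] * rows), int(bounds[1] * rows)
    return [1.0 if start <= t < end else 0.0 for t in range(rows)]


alpha = time_window_alpha(50, 0.8)
print(len(alpha), sum(alpha))  # 51 40.0
```

Because the bounds are multiplied by `num_steps + 1` and truncated, `0.8` with 50 steps covers rows 0 to 39, not 0 to 40.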
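`get_topk_similar_words` in `src/prompt_utils.py` ranks candidate words by the cosine similarity between the text encoder's `pooler_output` for each filled-in prompt and for the prompt filled with the base word. Both sides are L2-normalized, a matrix product yields the similarities, and the indices are sorted in descending order. `get_proxy_words` calls it twice. The first call uses the generic template `"a photo of a {word}"` to keep 50 candidates, and the second uses the user's prompt to pick `number_of_variations` proxy words from them. Here is a minimal sketch of the ranking step, with plain vectors standing in for CLIP embeddings. The toy vectors and the function name are illustrative only.

```python
import math


def topk_by_cosine(query, candidates, k):
    """Keys of `candidates` sorted by cosine similarity to `query`, highest first."""
    def unit(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    q = unit(query)
    # Dot product of unit vectors = cosine similarity.
    scores = {word: sum(a * b for a, b in zip(q, unit(vec))) for word, vec in candidates.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]


toy = {"cat": [0.9, 0.1], "dog": [0.5, 0.5], "car": [0.0, 1.0]}
print(topk_by_cosine([1.0, 0.0], toy, 2))  # ['cat', 'dog']
```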
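`get_mapper` in `src/seq_aligner.py` aligns the token ids of two prompts with `global_align`, a Needleman-Wunsch-style dynamic program. It uses `ScoreParams(0, 1, -1)`, that is, gap 0, match 1 and mismatch -1. It then reads the alignment back with `get_aligned_sequences`, which records, for each token of the second prompt, the index of the aligned token in the first prompt, or -1 for a gap. The sketch below reproduces the same recurrence, tie-breaking order (left, then up, then diagonal) and traceback codes on plain lists. It is a restatement for illustration only, and `align_y_to_x` is a hypothetical name.

```python
def align_y_to_x(x, y, gap=0, match=1, mismatch=-1):
    """Global alignment mirroring global_align + get_aligned_sequences.

    Returns (j, i) pairs: position j of y is aligned to position i of x,
    with i == -1 when y[j] is aligned to a gap.
    """
    n, m = len(x), len(y)
    score = [[0] * (m + 1) for _ in range(n + 1)]
    trace = [[0] * (m + 1) for _ in range(n + 1)]
    # Borders as in get_matrix / get_traceback_matrix: 1 = left, 2 = up, 4 = origin.
    for j in range(1, m + 1):
        score[0][j], trace[0][j] = j * gap, 1
    for i in range(1, n + 1):
        score[i][0], trace[i][0] = i * gap, 2
    trace[0][0] = 4
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            left = score[i][j - 1] + gap
            up = score[i - 1][j] + gap
            diag = score[i - 1][j - 1] + (match if x[i - 1] == y[j - 1] else mismatch)
            score[i][j] = max(left, up, diag)
            # Same tie-breaking order as the repository: left, then up, then diagonal (3).
            if score[i][j] == left:
                trace[i][j] = 1
            elif score[i][j] == up:
                trace[i][j] = 2
            else:
                trace[i][j] = 3
    # Walk back from the bottom-right corner.
    i, j, pairs = n, m, []
    while i > 0 or j > 0:
        code = trace[i][j]
        if code == 3:
            i, j = i - 1, j - 1
            pairs.append((j, i))
        elif code == 1:
            j -= 1
            pairs.append((j, -1))
        elif code == 2:
            i -= 1
        else:
            break
    pairs.reverse()
    return pairs


print(align_y_to_x(["<s>", "a", "cat", "</s>"], ["<s>", "a", "dog", "</s>"]))
# [(0, 0), (1, 1), (2, -1), (3, 3)]
```

With a zero gap penalty, skipping a mismatched pair costs less than substituting it, so a swapped word is mapped to -1. `get_mapper` turns exactly those entries into `alphas == 0`.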