├── .gitignore
├── README.md
├── docs
│   ├── generated.jpg
│   ├── real.jpg
│   ├── seg.jpg
│   └── teaser.jpg
├── gradio_app.py
├── lpm_env.yml
├── main.py
├── real_images
│   ├── lamp_simple.png
│   └── rinon_cat.jpg
├── requirements.txt
├── run_segmentation.py
├── src
│   ├── attention_based_segmentation.py
│   ├── attention_utils.py
│   ├── diffusion_model_wrapper.py
│   ├── null_text_inversion.py
│   ├── prompt_mixing.py
│   ├── prompt_to_prompt_controllers.py
│   ├── prompt_utils.py
│   └── seq_aligner.py
├── style.css
└── vocab.json
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | results/
3 | .vscode/
4 | segmentation_results/
5 | .idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Localizing Object-level Shape Variations with Text-to-Image Diffusion Models (ICCV 2023)
2 |
3 | > **Or Patashnik, Daniel Garibi, Idan Azuri, Hadar Averbuch-Elor, Daniel Cohen-Or**
4 | >
5 | > Text-to-image models give rise to workflows which often begin with an exploration step, where users sift through a large collection of generated images. The global nature of the text-to-image generation process prevents users from narrowing their exploration to a particular object in the image. In this paper, we present a technique to generate a collection of images that depicts variations in the shape of a specific object, enabling an object-level shape exploration process. Creating plausible variations is challenging as it requires control over the shape of the generated object while respecting its semantics. A particular challenge when generating object variations is accurately localizing the manipulation applied over the object's shape. We introduce a prompt-mixing technique that switches between prompts along the denoising process to attain a variety of shape choices. To localize the image-space operation, we present two techniques that use the self-attention layers in conjunction with the cross-attention layers. Moreover, we show that these localization techniques are general and effective beyond the scope of generating object variations. Extensive results and comparisons demonstrate the effectiveness of our method in generating object variations, and the competence of our localization techniques.
6 |
7 |
8 |
9 |
10 | [Hugging Face Spaces Demo](https://huggingface.co/spaces/orpatashnik/local-prompt-mixing)
11 |
12 |
13 |
14 |
15 |
16 | ## Description
17 | Official implementation of our Localizing Object-level Shape Variations with Text-to-Image Diffusion Models paper.
18 |
19 | ## Setup
20 |
21 | ### Environment
22 | Our code builds on the requirements of the official [Stable Diffusion repository](https://github.com/CompVis/stable-diffusion). To set up the environment, please run:
23 |
24 | ```
25 | conda env create -f lpm_env.yml
26 | conda activate lpm
27 | ```
28 |
29 | Then, please run the following in Python:
30 | ```python
31 | import nltk
32 | nltk.download('punkt')
33 | nltk.download('averaged_perceptron_tagger')
34 | ```
35 |
36 | This project has a Gradio [demo](https://huggingface.co/spaces/orpatashnik/local-prompt-mixing) deployed on Hugging Face.
37 | To run the demo locally, run the following:
38 | ```shell
39 | gradio gradio_app.py
40 | ```
41 | Then, you can connect to the local demo by browsing to `http://localhost:7860/`.
42 |
43 | ## Prompt Mix & Match Usage
44 |
45 | ### Generated Images
46 |
47 |
48 |
49 |
50 | Example generations by Stable Diffusion with variations produced by Prompt Mix & Match.
51 |
52 |
53 |
54 | To generate an image, you can simply run the `main.py` script. For example,
55 | ```
56 | python main.py --seed 48 --prompt "hamster eating {word} on the beach" --object_of_interest "watermelon" --background_nouns=["beach","hamster"]
57 | ```
58 | Notes:
59 |
60 | - To choose the number of variations, specify: `--number_of_variations 20`.
61 | - You may use your own proxy words instead of the auto-generated ones. For example, `--proxy_words=["pizza","ball","cube"]`.
62 | - To change the shape interval ($T_3$ and $T_2$ in the paper), specify: `--start_prompt_range 5 --end_prompt_range 15`.
63 | - You may use self-attention localization, for example `--objects_to_preserve=["hamster"]`.
64 | - You may also remove the object of interest from the self-attention mask by specifying `--remove_obj_from_self_mask True` (this flag is `True` by default). A combined example is shown below.
65 |
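A combined example using several of these flags (the flag names come from `LPMConfig` in `main.py`; the values here are only an illustration):
```
python main.py --seed 48 --prompt "hamster eating {word} on the beach" --object_of_interest "watermelon" --number_of_variations 10 --start_prompt_range 5 --end_prompt_range 15 --objects_to_preserve=["hamster"] --background_nouns=["beach","hamster"]
```
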
66 | All generated images will be saved to the path `"{exp_dir}/{prompt}/seed={seed}_{exp_name}/"`:
67 | ```
68 | {exp_dir}/
69 | |-- {prompt}/
70 | |   |-- seed={seed}_{exp_name}/
71 | |   |   |-- {object_of_interest}.jpg
72 | |   |   |-- {proxy_words[0]}.jpg
73 | |   |   |-- {proxy_words[1]}.jpg
74 | |   |   |-- ...
75 | |   |   |-- grid.jpg
76 | |   |   |-- opt.json
77 | |   |-- seed={seed}_{exp_name}/
78 | |   |   |-- {object_of_interest}.jpg
79 | |   |   |-- {proxy_words[0]}.jpg
80 | |   |   |-- {proxy_words[1]}.jpg
81 | |   |   |-- ...
82 | |   |   |-- grid.jpg
83 | |   |   |-- opt.json
84 | ...
85 | ```
86 | The default values are `--exp_dir "results" --exp_name ""`.
87 | ### Real Images
88 |
89 |
90 |
91 |
92 | Example real image with variations produced by Prompt Mix & Match.
93 |
94 |
95 | To generate variations for a real image, run the `main.py` script with the `--real_image_path` argument. For example,
96 | ```
97 | python main.py --real_image_path "real_images/lamp_simple.png" --prompt "A table below a {word}" --object_of_interest "lamp" --objects_to_preserve=["table"] --background_nouns=["table"]
98 | ```
99 |
100 | All generated images will be saved in the same format as for generated images.
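
Prompt Mix & Match can also be called from Python instead of the command line. Below is a minimal sketch based on how `gradio_app.py` and `main.py` use `setup` and `main`; the config values are illustrative:
```python
from main import LPMConfig, main, setup

args = LPMConfig(
    seed=48,
    prompt="hamster eating {word} on the beach",
    object_of_interest="watermelon",
    background_nouns=["beach", "hamster"],
)
stable, stable_config = setup(args)                      # load Stable Diffusion and its config once
images, proxy_words = main(stable, stable_config, args)  # list of image tensors + the proxy words used
```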
101 |
102 | ## Self-Segmentation
103 |
104 |
105 |
106 |
107 | Example segmentation of a real image.
108 |
109 |
110 | To get a segmentation of an image, you can simply run the `run_segmentation.py` script.
111 | The parameters of the segmentation are located in the `SegmentationConfig` class inside the script:
112 | - To use a real image, specify its path in the `real_image_path` attribute.
113 | - You may change the number of segments with the `num_segments` parameter (see the sketch after this list).
114 |
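A minimal sketch of overriding these fields when instantiating `SegmentationConfig` inside `run_segmentation.py` (the field names are taken from the class; the values are illustrative):
```python
args = SegmentationConfig(
    real_image_path="real_images/rinon_cat.jpg",   # real image to segment ("" generates an image from the prompt instead)
    prompt="a cat in a basket",                    # prompt describing the image content
    num_segments=5,                                # number of self-attention clusters
    background_segment_threshold=0.35,             # threshold for labeling a cluster as background
)
```
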
115 | The outputs will be saved to the path `"{exp_path}/"`:
116 | ```
117 | {exp_path}/
118 | |-- real_image.jpg
119 | |-- image_enc.jpg
120 | |-- image_rec.jpg
121 | |-- segmentation.jpg
122 | ```
123 | - `real_image.jpg` - Original image.
124 | - `image_enc.jpg` - Reconstruction of the image by the Stable Diffusion autoencoder only.
125 | - `image_rec.jpg` - Reconstruction of the image by the full Stable Diffusion pipeline.
126 | - `segmentation.jpg` - Segmentation output.
127 |
128 |
129 | ## Acknowledgements
130 | This code builds on the code from the [diffusers](https://github.com/huggingface/diffusers) library as well as the [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) codebase.
131 |
132 | ## Citation
133 |
134 | If you use this code for your research, please cite our paper:
135 |
136 | ```
137 | @InProceedings{patashnik2023localizing,
138 | author = {Patashnik, Or and Garibi, Daniel and Azuri, Idan and Averbuch-Elor, Hadar and Cohen-Or, Daniel},
139 | title = {Localizing Object-level Shape Variations with Text-to-Image Diffusion Models},
140 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
141 | year = {2023}
142 | }
143 | ```
144 |
--------------------------------------------------------------------------------
/docs/generated.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/generated.jpg
--------------------------------------------------------------------------------
/docs/real.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/real.jpg
--------------------------------------------------------------------------------
/docs/seg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/seg.jpg
--------------------------------------------------------------------------------
/docs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/docs/teaser.jpg
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gradio as gr
4 | import nltk
5 | import numpy as np
6 | from PIL import Image
7 |
8 | nltk.download('punkt')
9 | nltk.download('averaged_perceptron_tagger')
10 |
11 | from main import LPMConfig, main, setup
12 |
13 | DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models
14 | This is a demo for our ''Localizing Object-level Shape Variations with Text-to-Image Diffusion Models'' [paper](https://arxiv.org/abs/2303.11306).
15 | We introduce a method that generates object-level shape variation for a given image.
16 | This demo supports both generated images and real images. To modify a real image, please upload it to the input image block and provide a prompt that describes its contents.
17 |
18 | '''
19 |
20 | stable, stable_config = setup(LPMConfig())
21 |
22 | def main_pipeline(
23 | prompt: str,
24 | object_of_interest: str,
25 | proxy_words: str,
26 | number_of_variations: int,
27 | start_prompt_range: int,
28 | end_prompt_range: int,
29 | objects_to_preserve: str,
30 | background_nouns: str,
31 | seed: int,
32 | input_image: str):
33 | prompt = prompt.replace(object_of_interest, '{word}')
34 | proxy_words = proxy_words.split(',') if proxy_words != '' else []
35 | objects_to_preserve = objects_to_preserve.split(',') if objects_to_preserve != '' else []
36 | background_nouns = background_nouns.split(',') if background_nouns != '' else []
37 | args = LPMConfig(
38 | seed=seed,
39 | prompt=prompt,
40 | object_of_interest=object_of_interest,
41 | proxy_words=proxy_words,
42 | number_of_variations=number_of_variations,
43 | start_prompt_range=start_prompt_range,
44 | end_prompt_range=end_prompt_range,
45 | objects_to_preserve=objects_to_preserve,
46 | background_nouns=background_nouns,
47 | real_image_path="" if input_image is None else input_image
48 | )
49 |
50 | result_images, result_proxy_words = main(stable, stable_config, args)
51 | result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
52 | result_images = [(im * 255).astype(np.uint8) for im in result_images]
53 | result_images = [Image.fromarray(im) for im in result_images]
54 |
55 | return result_images, ",".join(result_proxy_words)
56 |
57 |
58 | with gr.Blocks(css='style.css') as demo:
59 | gr.Markdown(DESCRIPTION)
60 |
61 | gr.HTML(
62 |         '''
63 |         Duplicate the Space to run privately without waiting in queue''')
64 |
65 | with gr.Row():
66 | with gr.Column():
67 | input_image = gr.Image(
68 | label="Input image (optional)",
69 | type="filepath"
70 | )
71 | prompt = gr.Text(
72 | label='Prompt',
73 | max_lines=1,
74 | placeholder='A table below a lamp',
75 | )
76 | object_of_interest = gr.Text(
77 | label='Object of interest',
78 | max_lines=1,
79 | placeholder='lamp',
80 | )
81 | proxy_words = gr.Text(
82 | label='Proxy words - words used to obtain variations (a comma-separated list of words, can leave empty)',
83 | max_lines=1,
84 | placeholder=''
85 | )
86 | number_of_variations = gr.Slider(
87 | label='Number of variations (used only for automatic proxy-words)',
88 | minimum=2,
89 | maximum=30,
90 | value=7,
91 | step=1
92 | )
93 | start_prompt_range = gr.Slider(
94 | label='Number of steps before starting shape interval',
95 | minimum=0,
96 | maximum=50,
97 | value=7,
98 | step=1
99 | )
100 | end_prompt_range = gr.Slider(
101 | label='Number of steps before ending shape interval',
102 | minimum=1,
103 | maximum=50,
104 | value=17,
105 | step=1
106 | )
107 | objects_to_preserve = gr.Text(
108 | label='Words corresponding to objects to preserve (a comma-separated list of words, can leave empty)',
109 | max_lines=1,
110 | placeholder='table',
111 | )
112 | background_nouns = gr.Text(
113 | label='Words corresponding to objects that should be copied from original image (a comma-separated list of words, can leave empty)',
114 | max_lines=1,
115 | placeholder='',
116 | )
117 | seed = gr.Slider(
118 | label='Seed',
119 | minimum=1,
120 | maximum=100000,
121 | value=0,
122 | step=1
123 | )
124 |
125 | run_button = gr.Button('Generate')
126 | with gr.Column():
127 | result = gr.Gallery(label='Result').style(grid=4)
128 | proxy_words_result = gr.Text(label='Used proxy words')
129 |
130 | examples = [
131 | [
132 | "hamster eating watermelon on the beach",
133 | "watermelon",
134 | "",
135 | 7,
136 | 6,
137 | 16,
138 | "",
139 | "hamster,beach",
140 | 48,
141 | None
142 | ],
143 | [
144 | "A decorated lamp in the livingroom",
145 | "lamp",
146 | "",
147 | 7,
148 | 4,
149 | 14,
150 | "livingroom",
151 | "",
152 | 42,
153 | None
154 | ],
155 | [
156 | "a snake in the field eats an apple",
157 | "snake",
158 | "",
159 | 7,
160 | 7,
161 | 17,
162 | "apple",
163 | "apple,field",
164 | 10,
165 | None
166 | ]
167 | ]
168 |
169 | gr.Examples(examples=examples,
170 | inputs=[
171 | prompt,
172 | object_of_interest,
173 | proxy_words,
174 | number_of_variations,
175 | start_prompt_range,
176 | end_prompt_range,
177 | objects_to_preserve,
178 | background_nouns,
179 | seed,
180 | input_image
181 | ],
182 | outputs=[
183 | result,
184 | proxy_words_result
185 | ],
186 | fn=main_pipeline,
187 | cache_examples=False)
188 |
189 |
190 | inputs = [
191 | prompt,
192 | object_of_interest,
193 | proxy_words,
194 | number_of_variations,
195 | start_prompt_range,
196 | end_prompt_range,
197 | objects_to_preserve,
198 | background_nouns,
199 | seed,
200 | input_image
201 | ]
202 | outputs = [
203 | result,
204 | proxy_words_result
205 | ]
206 | run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
207 |
208 | demo.queue(max_size=50).launch(share=False)
--------------------------------------------------------------------------------
/lpm_env.yml:
--------------------------------------------------------------------------------
1 | name: lpm
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=2_gnu
11 | - anyio=3.6.2=pyhd8ed1ab_0
12 | - argon2-cffi=21.3.0=pyhd8ed1ab_0
13 | - argon2-cffi-bindings=21.2.0=py310h5764c6d_3
14 | - asttokens=2.2.1=pyhd8ed1ab_0
15 | - attrs=22.2.0=pyh71513ae_0
16 | - backcall=0.2.0=pyh9f0ad1d_0
17 | - backports=1.0=pyhd8ed1ab_3
18 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
19 | - beautifulsoup4=4.11.2=pyha770c72_0
20 | - blas=1.0=mkl
21 | - bleach=6.0.0=pyhd8ed1ab_0
22 | - brotli=1.0.9=h166bdaf_7
23 | - brotli-bin=1.0.9=h166bdaf_7
24 | - brotlipy=0.7.0=py310h7f8727e_1002
25 | - bzip2=1.0.8=h7b6447c_0
26 | - ca-certificates=2022.12.7=ha878542_0
27 | - certifi=2022.12.7=pyhd8ed1ab_0
28 | - cffi=1.15.1=py310h5eee18b_3
29 | - chardet=4.0.0=py310h06a4308_1003
30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
31 | - comm=0.1.2=pyhd8ed1ab_0
32 | - contourpy=1.0.5=py310hdb19cb5_0
33 | - cryptography=38.0.4=py310h9ce1e76_0
34 | - cuda=11.7.1=0
35 | - cuda-cccl=11.7.91=0
36 | - cuda-command-line-tools=11.7.1=0
37 | - cuda-compiler=11.7.1=0
38 | - cuda-cudart=11.7.99=0
39 | - cuda-cudart-dev=11.7.99=0
40 | - cuda-cuobjdump=11.7.91=0
41 | - cuda-cupti=11.7.101=0
42 | - cuda-cuxxfilt=11.7.91=0
43 | - cuda-demo-suite=12.0.76=0
44 | - cuda-documentation=12.0.76=0
45 | - cuda-driver-dev=11.7.99=0
46 | - cuda-gdb=12.0.90=0
47 | - cuda-libraries=11.7.1=0
48 | - cuda-libraries-dev=11.7.1=0
49 | - cuda-memcheck=11.8.86=0
50 | - cuda-nsight=12.0.78=0
51 | - cuda-nsight-compute=12.0.0=0
52 | - cuda-nvcc=11.7.99=0
53 | - cuda-nvdisasm=12.0.76=0
54 | - cuda-nvml-dev=11.7.91=0
55 | - cuda-nvprof=12.0.90=0
56 | - cuda-nvprune=11.7.91=0
57 | - cuda-nvrtc=11.7.99=0
58 | - cuda-nvrtc-dev=11.7.99=0
59 | - cuda-nvtx=11.7.91=0
60 | - cuda-nvvp=12.0.90=0
61 | - cuda-runtime=11.7.1=0
62 | - cuda-sanitizer-api=12.0.90=0
63 | - cuda-toolkit=11.7.1=0
64 | - cuda-tools=11.7.1=0
65 | - cuda-visual-tools=11.7.1=0
66 | - cycler=0.11.0=pyhd8ed1ab_0
67 | - dbus=1.13.18=hb2f20db_0
68 | - debugpy=1.5.1=py310h295c915_0
69 | - decorator=5.1.1=pyhd8ed1ab_0
70 | - defusedxml=0.7.1=pyhd8ed1ab_0
71 | - entrypoints=0.4=pyhd8ed1ab_0
72 | - executing=1.2.0=pyhd8ed1ab_0
73 | - expat=2.2.10=h9c3ff4c_0
74 | - ffmpeg=4.3=hf484d3e_0
75 | - flit-core=3.6.0=pyhd3eb1b0_0
76 | - fontconfig=2.14.1=hef1e5e3_0
77 | - fonttools=4.25.0=pyhd3eb1b0_0
78 | - freetype=2.12.1=h4a9f257_0
79 | - gds-tools=1.5.0.59=0
80 | - giflib=5.2.1=h7b6447c_0
81 | - glib=2.69.1=he621ea3_2
82 | - gmp=6.2.1=h295c915_3
83 | - gnutls=3.6.15=he1e5248_0
84 | - gst-plugins-base=1.14.0=h8213a91_2
85 | - gstreamer=1.14.0=h28cd5cc_2
86 | - icu=58.2=hf484d3e_1000
87 | - idna=3.4=py310h06a4308_0
88 | - importlib-metadata=6.0.0=pyha770c72_0
89 | - importlib_resources=5.10.2=pyhd8ed1ab_0
90 | - intel-openmp=2021.4.0=h06a4308_3561
91 | - ipykernel=6.19.2=py310h2f386ee_0
92 | - ipython=8.8.0=py310h06a4308_0
93 | - ipython_genutils=0.2.0=py_1
94 | - jedi=0.18.2=pyhd8ed1ab_0
95 | - jinja2=3.1.2=pyhd8ed1ab_1
96 | - jpeg=9e=h7f8727e_0
97 | - jsonschema=4.17.3=pyhd8ed1ab_0
98 | - jupyter_client=7.3.4=pyhd8ed1ab_0
99 | - jupyter_core=4.12.0=py310hff52083_0
100 | - jupyter_server=1.23.5=pyhd8ed1ab_0
101 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
102 | - keyutils=1.6.1=h166bdaf_0
103 | - kiwisolver=1.4.4=py310h6a678d5_0
104 | - krb5=1.19.3=h3790be6_0
105 | - lame=3.100=h7b6447c_0
106 | - lcms2=2.12=h3be6417_0
107 | - ld_impl_linux-64=2.38=h1181459_1
108 | - lerc=3.0=h295c915_0
109 | - libbrotlicommon=1.0.9=h166bdaf_7
110 | - libbrotlidec=1.0.9=h166bdaf_7
111 | - libbrotlienc=1.0.9=h166bdaf_7
112 | - libclang=10.0.1=default_hb85057a_2
113 | - libcublas=11.10.3.66=0
114 | - libcublas-dev=11.10.3.66=0
115 | - libcufft=10.7.2.124=h4fbf590_0
116 | - libcufft-dev=10.7.2.124=h98a8f43_0
117 | - libcufile=1.5.0.59=0
118 | - libcufile-dev=1.5.0.59=0
119 | - libcurand=10.3.1.50=0
120 | - libcurand-dev=10.3.1.50=0
121 | - libcusolver=11.4.0.1=0
122 | - libcusolver-dev=11.4.0.1=0
123 | - libcusparse=11.7.4.91=0
124 | - libcusparse-dev=11.7.4.91=0
125 | - libdeflate=1.8=h7f8727e_5
126 | - libedit=3.1.20191231=he28a2e2_2
127 | - libevent=2.1.12=h8f2d780_0
128 | - libffi=3.4.2=h6a678d5_6
129 | - libgcc-ng=12.2.0=h65d4601_19
130 | - libgomp=12.2.0=h65d4601_19
131 | - libiconv=1.16=h7f8727e_2
132 | - libidn2=2.3.2=h7f8727e_0
133 | - libllvm10=10.0.1=he513fc3_3
134 | - libnpp=11.7.4.75=0
135 | - libnpp-dev=11.7.4.75=0
136 | - libnvjpeg=11.8.0.2=0
137 | - libnvjpeg-dev=11.8.0.2=0
138 | - libpng=1.6.37=hbc83047_0
139 | - libpq=12.9=h16c4e8d_3
140 | - libsodium=1.0.18=h36c2ea0_1
141 | - libstdcxx-ng=12.2.0=h46fd767_19
142 | - libtasn1=4.16.0=h27cfd23_0
143 | - libtiff=4.5.0=hecacb30_0
144 | - libunistring=0.9.10=h27cfd23_0
145 | - libuuid=1.41.5=h5eee18b_0
146 | - libwebp=1.2.4=h11a3e52_0
147 | - libwebp-base=1.2.4=h5eee18b_0
148 | - libxcb=1.15=h7f8727e_0
149 | - libxkbcommon=1.0.3=he3ba5ed_0
150 | - libxml2=2.9.14=h74e7548_0
151 | - libxslt=1.1.35=h4e12654_0
152 | - lz4-c=1.9.4=h6a678d5_0
153 | - markupsafe=2.1.2=py310h1fa729e_0
154 | - matplotlib=3.6.2=py310hff52083_0
155 | - matplotlib-base=3.6.2=py310h945d387_0
156 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
157 | - mistune=2.0.5=pyhd8ed1ab_0
158 | - mkl=2021.4.0=h06a4308_640
159 | - mkl-service=2.4.0=py310h7f8727e_0
160 | - mkl_fft=1.3.1=py310hd6ae3a3_0
161 | - mkl_random=1.2.2=py310h00e6091_0
162 | - munkres=1.1.4=pyh9f0ad1d_0
163 | - nbclassic=0.5.1=pyhd8ed1ab_0
164 | - nbclient=0.7.2=pyhd8ed1ab_0
165 | - nbconvert=7.2.9=pyhd8ed1ab_0
166 | - nbconvert-core=7.2.9=pyhd8ed1ab_0
167 | - nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
168 | - nbformat=5.7.3=pyhd8ed1ab_0
169 | - ncurses=6.3=h5eee18b_3
170 | - nest-asyncio=1.5.6=pyhd8ed1ab_0
171 | - nettle=3.7.3=hbbd107a_1
172 | - notebook=6.5.2=pyha770c72_1
173 | - notebook-shim=0.2.2=pyhd8ed1ab_0
174 | - nsight-compute=2022.4.0.15=0
175 | - nspr=4.33=h295c915_0
176 | - nss=3.74=h0370c37_0
177 | - numpy=1.23.5=py310hd5efca6_0
178 | - numpy-base=1.23.5=py310h8e6c178_0
179 | - openh264=2.1.1=h4ff587b_0
180 | - openssl=1.1.1t=h0b41bf4_0
181 | - packaging=23.0=pyhd8ed1ab_0
182 | - pandoc=2.19.2=ha770c72_0
183 | - pandocfilters=1.5.0=pyhd8ed1ab_0
184 | - parso=0.8.3=pyhd8ed1ab_0
185 | - pcre=8.45=h9c3ff4c_0
186 | - pexpect=4.8.0=pyh1a96a4e_2
187 | - pickleshare=0.7.5=py_1003
188 | - pillow=9.3.0=py310h6a678d5_2
189 | - pip=22.3.1=py310h06a4308_0
190 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
191 | - ply=3.11=py_1
192 | - prometheus_client=0.16.0=pyhd8ed1ab_0
193 | - prompt-toolkit=3.0.36=pyha770c72_0
194 | - psutil=5.9.4=py310h5764c6d_0
195 | - ptyprocess=0.7.0=pyhd3deb0d_0
196 | - pure_eval=0.2.2=pyhd8ed1ab_0
197 | - pycparser=2.21=pyhd3eb1b0_0
198 | - pygments=2.14.0=pyhd8ed1ab_0
199 | - pyopenssl=22.0.0=pyhd3eb1b0_0
200 | - pyparsing=3.0.9=pyhd8ed1ab_0
201 | - pyqt=5.15.7=py310h6a678d5_1
202 | - pyrsistent=0.19.3=py310h1fa729e_0
203 | - pysocks=1.7.1=py310h06a4308_0
204 | - python=3.10.8=h7a1cb2a_1
205 | - python-dateutil=2.8.2=pyhd8ed1ab_0
206 | - python-fastjsonschema=2.16.2=pyhd8ed1ab_0
207 | - python_abi=3.10=2_cp310
208 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0
209 | - pytorch-cuda=11.7=h67b0de4_1
210 | - pytorch-mutex=1.0=cuda
211 | - pyzmq=25.0.0=py310h059b190_0
212 | - qt-main=5.15.2=h327a75a_7
213 | - qt-webengine=5.15.9=hd2b0992_4
214 | - qtwebkit=5.212=h4eab89a_4
215 | - readline=8.2=h5eee18b_0
216 | - requests=2.28.1=py310h06a4308_0
217 | - send2trash=1.8.0=pyhd8ed1ab_0
218 | - setuptools=65.6.3=py310h06a4308_0
219 | - sip=6.6.2=py310h6a678d5_0
220 | - six=1.16.0=pyhd3eb1b0_1
221 | - sniffio=1.3.0=pyhd8ed1ab_0
222 | - soupsieve=2.3.2.post1=pyhd8ed1ab_0
223 | - sqlite=3.40.1=h5082296_0
224 | - stack_data=0.6.2=pyhd8ed1ab_0
225 | - terminado=0.17.1=pyh41d4057_0
226 | - tinycss2=1.2.1=pyhd8ed1ab_0
227 | - tk=8.6.12=h1ccaba5_0
228 | - toml=0.10.2=pyhd8ed1ab_0
229 | - torchaudio=0.13.1=py310_cu117
230 | - torchvision=0.14.1=py310_cu117
231 | - tornado=6.2=py310h5eee18b_0
232 | - traitlets=5.7.1=py310h06a4308_0
233 | - typing_extensions=4.4.0=py310h06a4308_0
234 | - tzdata=2022g=h04d1e81_0
235 | - urllib3=1.26.14=py310h06a4308_0
236 | - wcwidth=0.2.6=pyhd8ed1ab_0
237 | - webencodings=0.5.1=py_1
238 | - websocket-client=1.5.1=pyhd8ed1ab_0
239 | - wheel=0.37.1=pyhd3eb1b0_0
240 | - xz=5.2.8=h5eee18b_0
241 | - zeromq=4.3.4=h9c3ff4c_1
242 | - zipp=3.11.0=py310h06a4308_0
243 | - zlib=1.2.13=h5eee18b_0
244 | - zstd=1.5.2=ha4553b6_0
245 | - pip:
246 | - accelerate==0.18.0
247 | - click==8.1.3
248 | - diffusers==0.10.2
249 | - filelock==3.10.4
250 | - huggingface-hub==0.13.3
251 | - joblib==1.2.0
252 | - mypy-extensions==1.0.0
253 | - nltk==3.8.1
254 | - opencv-python==4.7.0.72
255 | - pyqt5-sip==12.11.0
256 | - pyrallis==0.3.1
257 | - pyyaml==6.0
258 | - regex==2023.3.23
259 | - scikit-learn==1.2.2
260 | - scipy==1.10.1
261 | - threadpoolctl==3.1.0
262 | - tokenizers==0.13.2
263 | - tqdm==4.65.0
264 | - transformers==4.25.1
265 | - typing-inspect==0.8.0
266 | prefix: /home/danielgaribi/miniconda3/envs/lpm
267 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass, field
4 | from typing import List
5 |
6 | import pyrallis
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from torchvision.transforms import ToTensor
10 | from torchvision.utils import save_image
11 | from tqdm import tqdm
12 |
13 | from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
14 | generate_original_image
15 | from src.null_text_inversion import invert_image
16 | from src.prompt_mixing import PromptMixing
17 | from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
18 | from src.prompt_utils import get_proxy_prompts
19 |
20 |
21 | def save_args_dict(args, similar_words):
22 | exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}")
23 | os.makedirs(exp_path, exist_ok=True)
24 |
25 | args_dict = vars(args)
26 | args_dict['similar_words'] = similar_words
27 | with open(os.path.join(exp_path, "opt.json"), 'w') as fp:
28 | json.dump(args_dict, fp, sort_keys=True, indent=4)
29 |
30 | return exp_path
31 |
32 | def setup(args):
33 | ldm_stable = get_stable_diffusion_model(args)
34 | ldm_stable_config = get_stable_diffusion_config(args)
35 | return ldm_stable, ldm_stable_config
36 |
37 |
38 | def main(ldm_stable, ldm_stable_config, args):
39 |
40 | similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
41 | exp_path = save_args_dict(args, similar_words)
42 |
43 | images = []
44 | x_t = None
45 | uncond_embeddings = None
46 |
47 | if args.real_image_path != "":
48 | ldm_stable, ldm_stable_config = setup(args)
49 | x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)
50 |
51 | image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
52 | save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
53 | save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg")
54 | images.append(image[0])
55 |
56 | object_of_interest_index = args.prompt.split().index('{word}') + 1
57 | pm = PromptMixing(args, object_of_interest_index, average_attention)
58 |
59 | do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0
60 | do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0
61 | if do_other_obj_self_attn_masking:
62 | print("Do self attn other obj masking")
63 | if do_self_or_cross_attn_inject:
64 | print(f'Do self attn inject for {args.self_attn_inject_steps} steps')
65 | print(f'Do cross attn inject for {args.cross_attn_inject_steps} steps')
66 |
67 | another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False)
68 |
69 | for another_prompt_batch in tqdm(another_prompts_dataloader):
70 | batch_size = len(another_prompt_batch["word"])
71 | batch_prompts = prompts * batch_size
72 | batch_another_prompt = another_prompt_batch["prompt"]
73 | if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking:
74 | batch_prompts.append(prompts[0])
75 | batch_another_prompt.insert(0, prompts[0])
76 |
77 | if do_self_or_cross_attn_inject:
78 | controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device,
79 | ldm_stable_config["low_resource"], ldm_stable_config["num_diffusion_steps"],
80 | cross_replace_steps=args.cross_attn_inject_steps,
81 | self_replace_steps=args.self_attn_inject_steps)
82 | else:
83 | controller = AttentionStore(ldm_stable_config["low_resource"])
84 |
85 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, prompt_mixing=pm)
86 | with torch.no_grad():
87 | image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t, other_prompt=batch_another_prompt,
88 | post_background=args.background_post_process, orig_all_latents=orig_all_latents,
89 | orig_mask=orig_mask, uncond_embeddings=uncond_embeddings)
90 |
91 | for i in range(batch_size):
92 | image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i
93 | save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg")
94 | if mask is not None:
95 | save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg")
96 | images.append(image[image_index])
97 |
98 | images = [ToTensor()(image) for image in images]
99 |     save_image(images, f"{exp_path}/grid.jpg", nrow=min(max([i for i in range(2, 8) if len(images) % i == 0], default=1), 8))
100 | return images, similar_words
101 |
102 |
103 | @dataclass
104 | class LPMConfig:
105 |
106 | # general config
107 | seed: int = 10
108 | batch_size: int = 1
109 | exp_dir: str = "results"
110 | exp_name: str = ""
111 | display_images: bool = False
112 | gpu_id: int = 0
113 |
114 | # Stable Diffusion config
115 | auth_token: str = ""
116 | low_resource: bool = True
117 | num_diffusion_steps: int = 50
118 | guidance_scale: float = 7.5
119 | max_num_words: int = 77
120 |
121 | # prompt-mixing
122 | prompt: str = "a {word} in the field eats an apple"
123 | object_of_interest: str = "snake" # The object for which we generate variations
124 | proxy_words: List[str] = field(default_factory=lambda :[]) # Leave empty for automatic proxy words
125 | number_of_variations: int = 20
126 | start_prompt_range: int = 7 # Number of steps to begin prompt-mixing
127 | end_prompt_range: int = 17 # Number of steps to finish prompt-mixing
128 |
129 | # attention based shape localization
130 | objects_to_preserve: List[str] = field(default_factory=lambda :[]) # Objects for which apply attention based shape localization
131 | remove_obj_from_self_mask: bool = True # If set to True, removes the object of interest from the self-attention mask
132 | obj_pixels_injection_threshold: float = 0.05
133 | end_preserved_obj_self_attn_masking: int = 40
134 |
135 | # real image
136 | real_image_path: str = ""
137 |
138 | # controllable background preservation
139 | background_post_process: bool = True
140 | background_nouns: List[str] = field(default_factory=lambda :[]) # Objects to take from the original image in addition to the background
141 | num_segments: int = 5 # Number of clusters for the segmentation
142 | background_segment_threshold: float = 0.3 # Threshold for the segments labeling
143 | background_blend_timestep: int = 35 # Number of steps before background blending
144 |
145 | # other
146 | cross_attn_inject_steps: float = 0.0
147 | self_attn_inject_steps: float = 0.0
148 |
149 |
150 | if __name__ == '__main__':
151 | args = pyrallis.parse(config_class=LPMConfig)
152 |
153 | print(args)
154 |
155 | stable, stable_config = setup(args)
156 | main(stable, stable_config, args)
157 |
--------------------------------------------------------------------------------
/real_images/lamp_simple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/real_images/lamp_simple.png
--------------------------------------------------------------------------------
/real_images/rinon_cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orpatashnik/local-prompt-mixing/4bc225db751ae755604677cb2cdbb7bd8e732fe8/real_images/rinon_cat.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.10.2
2 | opencv-python==4.7.0.72
3 | pyrallis==0.3.1
4 | torch==1.13.1
5 | torchvision==0.14.1
6 | transformers==4.25.1
7 | nltk==3.8.1
8 | scipy
9 | scikit-learn
10 | accelerate
11 | gradio
--------------------------------------------------------------------------------
/run_segmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 | import torch
5 | from torchvision.transforms import ToTensor
6 | from torchvision.utils import save_image
7 | import matplotlib.pyplot as plt
8 |
9 | from src.attention_based_segmentation import Segmentor
10 | from src.diffusion_model_wrapper import get_stable_diffusion_model, get_stable_diffusion_config, DiffusionModelWrapper
11 | from src.null_text_inversion import invert_image
12 | from src.prompt_to_prompt_controllers import AttentionStore
13 |
14 | @dataclass
15 | class SegmentationConfig:
16 | seed: int = 1111
17 | gpu_id: int = 0
18 | real_image_path: str = "real_images/rinon_cat.jpg"
19 | auth_token: str = ""
20 | low_resource: bool = True
21 | num_diffusion_steps: int = 50
22 | guidance_scale: float = 7.5
23 | max_num_words: int = 77
24 | prompt: str = "a cat in a basket"
25 | exp_path: str = "segmentation_results"
26 |
27 | num_segments: int = 5
28 | background_segment_threshold: float = 0.35
29 |
30 | if __name__ == '__main__':
31 | args = SegmentationConfig()
32 | os.makedirs(args.exp_path, exist_ok=True)
33 | ldm_stable = get_stable_diffusion_model(args)
34 | ldm_stable_config = get_stable_diffusion_config(args)
35 |
36 | x_t = None
37 | uncond_embeddings = None
38 | if args.real_image_path != "":
39 | x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, [args.prompt], args.exp_path)
40 |
41 | g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
42 | controller = AttentionStore(ldm_stable_config["low_resource"])
43 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu)
44 | image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward([args.prompt],
45 | latent=x_t,
46 | uncond_embeddings=uncond_embeddings)
47 | segmentor = Segmentor(controller, [args.prompt], args.num_segments, args.background_segment_threshold)
48 | clusters = segmentor.cluster()
49 | cluster2noun = segmentor.cluster2noun(clusters)
50 |
51 | save_image(ToTensor()(image[0]), f"{args.exp_path}/image_rec.jpg")
52 | plt.imshow(clusters)
53 | plt.axis('off')
54 | plt.savefig(f"{args.exp_path}/segmentation.jpg", bbox_inches='tight', pad_inches=0)
55 |
--------------------------------------------------------------------------------
/src/attention_based_segmentation.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | from sklearn.cluster import KMeans
3 | import numpy as np
4 |
5 | from src.attention_utils import aggregate_attention
6 |
7 |
8 | class Segmentor:
9 |
10 | def __init__(self, controller, prompts, num_segments, background_segment_threshold, res=32, background_nouns=[]):
11 | self.controller = controller
12 | self.prompts = prompts
13 | self.num_segments = num_segments
14 | self.background_segment_threshold = background_segment_threshold
15 | self.resolution = res
16 | self.background_nouns = background_nouns
17 |
18 | self.self_attention = aggregate_attention(controller, res=32, from_where=("up", "down"), prompts=prompts,
19 | is_cross=False, select=len(prompts) - 1)
20 | self.cross_attention = aggregate_attention(controller, res=16, from_where=("up", "down"), prompts=prompts,
21 | is_cross=True, select=len(prompts) - 1)
22 | tokenized_prompt = nltk.word_tokenize(prompts[-1])
23 | self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt)) if pos[:2] == 'NN']
24 |
25 | def __call__(self, *args, **kwargs):
26 | clusters = self.cluster()
27 | cluster2noun = self.cluster2noun(clusters)
28 | return cluster2noun
29 |
30 | def cluster(self):
31 | np.random.seed(1)
32 | resolution = self.self_attention.shape[0]
33 | attn = self.self_attention.cpu().numpy().reshape(resolution ** 2, resolution ** 2)
34 | kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(attn)
35 | clusters = kmeans.labels_
36 | clusters = clusters.reshape(resolution, resolution)
37 | return clusters
38 |
39 | def cluster2noun(self, clusters):
40 | result = {}
41 | nouns_indices = [index for (index, word) in self.nouns]
42 | nouns_maps = self.cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
43 | normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
44 | for i in range(nouns_maps.shape[-1]):
45 | curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
46 | normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
47 | for c in range(self.num_segments):
48 | cluster_mask = np.zeros_like(clusters)
49 | cluster_mask[clusters == c] = 1
50 | score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
51 | scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
52 | result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > self.background_segment_threshold else "BG"
53 | return result
54 |
55 | def get_background_mask(self, obj_token_index):
56 | clusters = self.cluster()
57 | cluster2noun = self.cluster2noun(clusters)
58 | mask = clusters.copy()
59 | obj_segments = [c for c in cluster2noun if cluster2noun[c][0] == obj_token_index - 1]
60 | background_segments = [c for c in cluster2noun if cluster2noun[c] == "BG" or cluster2noun[c][1] in self.background_nouns]
61 | for c in range(self.num_segments):
62 | if c in background_segments and c not in obj_segments:
63 | mask[clusters == c] = 0
64 | else:
65 | mask[clusters == c] = 1
66 | return mask
67 |
68 |
--------------------------------------------------------------------------------
/src/attention_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Tuple, List
4 | from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX
5 | import matplotlib.pyplot as plt  # used by display() below
6 | from PIL import Image
7 |
8 | from src.prompt_to_prompt_controllers import AttentionStore
9 |
10 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts):
11 | out = []
12 | attention_maps = attention_store.get_average_attention()
13 | num_pixels = res ** 2
14 | for location in from_where:
15 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
16 | if item.shape[1] == num_pixels:
17 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
18 | out.append(cross_maps)
19 | out = torch.cat(out, dim=0)
20 | out = out.sum(0) / out.shape[0]
21 | return out.cpu()
22 |
23 |
24 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], prompts, tokenizer, select: int = 0):
25 | tokens = tokenizer.encode(prompts[select])
26 | decoder = tokenizer.decode
27 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
28 | images = []
29 | for i in range(len(tokens)):
30 | image = attention_maps[:, :, i]
31 | image = 255 * image / image.max()
32 | image = image.unsqueeze(-1).expand(*image.shape, 3)
33 | image = image.numpy().astype(np.uint8)
34 | image = np.array(Image.fromarray(image).resize((256, 256)))
35 | image = text_under_image(image, decoder(int(tokens[i])))
36 | images.append(image)
37 | view_images(np.stack(images, axis=0))
38 |
39 |
40 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], prompts,
41 |                              max_com=10, select: int = 0):
42 |     attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape(
43 |         (res ** 2, res ** 2))
44 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
45 | images = []
46 | for i in range(max_com):
47 | image = vh[i].reshape(res, res)
48 | image = image - image.min()
49 | image = 255 * image / image.max()
50 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
51 | image = Image.fromarray(image).resize((256, 256))
52 | image = np.array(image)
53 | images.append(image)
54 | view_images(np.concatenate(images, axis=1))
55 |
56 |
57 | def view_images(images, num_rows=1, offset_ratio=0.02):
58 | if type(images) is list:
59 | num_empty = len(images) % num_rows
60 | elif images.ndim == 4:
61 | num_empty = images.shape[0] % num_rows
62 | else:
63 | images = [images]
64 | num_empty = 0
65 |
66 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
67 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
68 | num_items = len(images)
69 |
70 | h, w, c = images[0].shape
71 | offset = int(h * offset_ratio)
72 | num_cols = num_items // num_rows
73 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
74 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
75 | for i in range(num_rows):
76 | for j in range(num_cols):
77 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
78 | i * num_cols + j]
79 |
80 | pil_img = Image.fromarray(image_)
81 | display(pil_img)
82 |
83 |
84 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
85 | h, w, c = image.shape
86 | offset = int(h * .2)
87 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
88 | font = FONT_HERSHEY_SIMPLEX
89 | img[:h] = image
90 | textsize = getTextSize(text, font, 1, 2)[0]
91 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
92 | putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
93 | return img
94 |
95 |
96 | def display(image):
97 | global display_index
98 | plt.imshow(image)
99 | plt.show()
100 |
--------------------------------------------------------------------------------
/src/diffusion_model_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import numpy as np
4 | import torch
5 | from cv2 import dilate
6 | from diffusers import DDIMScheduler, StableDiffusionPipeline
7 | from tqdm import tqdm
8 |
9 | from src.attention_based_segmentation import Segmentor
10 | from src.attention_utils import show_cross_attention
11 | from src.prompt_to_prompt_controllers import DummyController, AttentionStore
12 |
13 |
14 | def get_stable_diffusion_model(args):
15 | device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
16 | if args.real_image_path != "":
17 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
18 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token, scheduler=scheduler).to(device)
19 | else:
20 | ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token).to(device)
21 |
22 | return ldm_stable
23 |
24 | def get_stable_diffusion_config(args):
25 | return {
26 | "low_resource": args.low_resource,
27 | "num_diffusion_steps": args.num_diffusion_steps,
28 | "guidance_scale": args.guidance_scale,
29 | "max_num_words": args.max_num_words
30 | }
31 |
32 |
33 | def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
34 | g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
35 | controller = AttentionStore(ldm_stable_config["low_resource"])
36 | diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu)
37 | image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts,
38 | latent=latent,
39 | uncond_embeddings=uncond_embeddings)
40 | orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold, background_nouns=args.background_nouns)\
41 | .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
42 | average_attention = controller.get_average_attention()
43 | return image, x_t, orig_all_latents, orig_mask, average_attention
44 |
45 |
46 | class DiffusionModelWrapper:
47 | def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
48 | self.args = args
49 | self.model = model
50 | self.model_config = model_config
51 | self.controller = controller
52 | if self.controller is None:
53 | self.controller = DummyController()
54 | self.prompt_mixing = prompt_mixing
55 | self.device = model.device
56 | self.generator = generator
57 |
58 | self.height = 512
59 | self.width = 512
60 |
61 | self.diff_step = 0
62 | self.register_attention_control()
63 |
64 |
65 | def diffusion_step(self, latents, context, t, other_context=None):
66 | if self.model_config["low_resource"]:
67 | self.uncond_pred = True
68 | noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
69 | self.uncond_pred = False
70 | noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"]
71 | else:
72 | latents_input = torch.cat([latents] * 2)
73 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
74 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
75 | noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
76 | latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
77 | latents = self.controller.step_callback(latents)
78 | return latents
79 |
80 |
81 | def latent2image(self, latents):
82 | latents = 1 / 0.18215 * latents
83 | image = self.model.vae.decode(latents)['sample']
84 | image = (image / 2 + 0.5).clamp(0, 1)
85 | image = image.cpu().permute(0, 2, 3, 1).numpy()
86 | image = (image * 255).astype(np.uint8)
87 | return image
88 |
89 |
90 | def init_latent(self, latent, batch_size):
91 | if latent is None:
92 | latent = torch.randn(
93 | (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
94 | generator=self.generator, device=self.model.device
95 | )
96 | latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device)
97 | return latent, latents
98 |
99 |
100 | def register_attention_control(self):
101 | def ca_forward(model_self, place_in_unet):
102 | to_out = model_self.to_out
103 | if type(to_out) is torch.nn.modules.container.ModuleList:
104 | to_out = model_self.to_out[0]
105 | else:
106 | to_out = model_self.to_out
107 |
108 | def forward(x, context=None, mask=None):
109 | batch_size, sequence_length, dim = x.shape
110 | h = model_self.heads
111 | q = model_self.to_q(x)
112 | is_cross = context is not None
113 | context = context if is_cross else (x, None)
114 |
115 | k = model_self.to_k(context[0])
116 | if is_cross and self.prompt_mixing is not None:
117 | v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
118 | v = model_self.to_v(v_context)
119 | else:
120 | v = model_self.to_v(context[0])
121 |
122 | q = model_self.reshape_heads_to_batch_dim(q)
123 | k = model_self.reshape_heads_to_batch_dim(k)
124 | v = model_self.reshape_heads_to_batch_dim(v)
125 |
126 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale
127 |
128 | if mask is not None:
129 | mask = mask.reshape(batch_size, -1)
130 | max_neg_value = -torch.finfo(sim.dtype).max
131 | mask = mask[:, None, :].repeat(h, 1, 1)
132 | sim.masked_fill_(~mask, max_neg_value)
133 |
134 | # attention, what we cannot get enough of
135 | attn = sim.softmax(dim=-1)
136 | if self.enbale_attn_controller_changes:
137 | attn = self.controller(attn, is_cross, place_in_unet)
138 |
139 | if is_cross and self.prompt_mixing is not None and context[1] is not None:
140 | attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)
141 |
142 | if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
143 | attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)
144 |
145 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
146 | out = model_self.reshape_batch_dim_to_heads(out)
147 | return to_out(out)
148 |
149 | return forward
150 |
151 | def register_recr(net_, count, place_in_unet):
152 | if net_.__class__.__name__ == 'CrossAttention':
153 | net_.forward = ca_forward(net_, place_in_unet)
154 | return count + 1
155 | elif hasattr(net_, 'children'):
156 | for net__ in net_.children():
157 | count = register_recr(net__, count, place_in_unet)
158 | return count
159 |
160 | cross_att_count = 0
161 | sub_nets = self.model.unet.named_children()
162 | for net in sub_nets:
163 | if "down" in net[0]:
164 | cross_att_count += register_recr(net[1], 0, "down")
165 | elif "up" in net[0]:
166 | cross_att_count += register_recr(net[1], 0, "up")
167 | elif "mid" in net[0]:
168 | cross_att_count += register_recr(net[1], 0, "mid")
169 | self.controller.num_att_layers = cross_att_count
170 |
171 |
172 | def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
173 | text_input = self.model.tokenizer(
174 | prompt,
175 | padding="max_length",
176 | max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
177 | truncation=truncation,
178 | return_tensors="pt",
179 | )
180 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
181 | max_length = text_input.input_ids.shape[-1]
182 | return text_embeddings, max_length
183 |
184 |
185 | @torch.no_grad()
186 | def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None,
187 | other_prompt: List[str] = None, post_background = False, orig_all_latents = None, orig_mask = None,
188 | uncond_embeddings=None, start_time=51, return_type='image'):
189 | self.enbale_attn_controller_changes = True
190 | batch_size = len(prompt)
191 |
192 | text_embeddings, max_length = self.get_text_embedding(prompt)
193 | if uncond_embeddings is None:
194 | uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False)
195 | else:
196 | uncond_embeddings_ = None
197 |
198 | other_context = None
199 | if other_prompt is not None:
200 | other_text_embeddings, _ = self.get_text_embedding(other_prompt)
201 | other_context = other_text_embeddings
202 |
203 | latent, latents = self.init_latent(latent, batch_size)
204 |
205 | # set timesteps
206 | self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
207 | all_latents = []
208 |
209 | object_mask = None
210 | self.diff_step = 0
211 | for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
212 | if uncond_embeddings_ is None:
213 | context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
214 | else:
215 | context = [uncond_embeddings_, text_embeddings]
216 | if not self.model_config["low_resource"]:
217 | context = torch.cat(context)
218 |
219 | self.down_cross_index = 0
220 | self.mid_cross_index = 0
221 | self.up_cross_index = 0
222 | latents = self.diffusion_step(latents, context, t, other_context)
223 |
224 | if post_background and self.diff_step == self.args.background_blend_timestep:
225 | object_mask = Segmentor(self.controller,
226 | prompt,
227 | self.args.num_segments,
228 | self.args.background_segment_threshold,
229 | background_nouns=self.args.background_nouns)\
230 | .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
231 | self.enbale_attn_controller_changes = False
232 | mask = object_mask.astype(np.bool8) + orig_mask.astype(np.bool8)
233 | mask = torch.from_numpy(mask).float().cuda()
234 | shape = (1, 1, mask.shape[0], mask.shape[1])
235 | mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
236 | mask_eroded = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
237 | mask = torch.from_numpy(mask_eroded).float().cuda().view(1, 1, 64, 64)
238 | latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]
239 |
240 | all_latents.append(latents)
241 | self.diff_step += 1
242 |
243 | if return_type == 'image':
244 | image = self.latent2image(latents)
245 | else:
246 | image = latents
247 |
248 | return image, latent, all_latents, object_mask
249 |
250 |
251 | def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
252 | show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select)
--------------------------------------------------------------------------------
/src/null_text_inversion.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from torchvision.transforms import ToTensor
4 | from torchvision.utils import save_image
5 | from tqdm import tqdm
6 | import torch
7 | from torch.optim.adam import Adam
8 | import torch.nn.functional as nnf
9 | import numpy as np
10 | from PIL import Image
11 |
12 |
13 | def load_512(image_path, left=0, right=0, top=0, bottom=0):
14 | if type(image_path) is str:
15 | image = np.array(Image.open(image_path))[:, :, :3]
16 | else:
17 | image = image_path
18 | h, w, c = image.shape
19 | left = min(left, w-1)
20 | right = min(right, w - left - 1)
21 | top = min(top, h - left - 1)
22 | bottom = min(bottom, h - top - 1)
23 | image = image[top:h-bottom, left:w-right]
24 | h, w, c = image.shape
25 | if h < w:
26 | offset = (w - h) // 2
27 | image = image[:, offset:offset + h]
28 | elif w < h:
29 | offset = (h - w) // 2
30 | image = image[offset:offset + w]
31 | image = np.array(Image.fromarray(image).resize((512, 512)))
32 | return image
33 |
34 |
35 | def invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path):
36 | print("Start null text inversion")
37 | null_inversion = NullInversion(ldm_stable, ldm_stable_config)
38 | (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(args.real_image_path, prompts[0], offsets=(0,0,0,0), verbose=True)
39 | save_image(ToTensor()(image_gt), f"{exp_path}/real_image.jpg")
40 | save_image(ToTensor()(image_enc), f"{exp_path}/image_enc.jpg")
41 | print("End null text inversion")
42 | return x_t, uncond_embeddings
43 |
44 |
45 | class NullInversion:
46 |
47 | def __init__(self, model, model_config):
48 | self.model = model
49 | self.model_config = model_config
50 | self.tokenizer = self.model.tokenizer
51 | self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
52 | self.prompt = None
53 | self.context = None
54 |
55 |
56 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
57 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
58 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
59 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
60 | beta_prod_t = 1 - alpha_prod_t
61 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
62 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
63 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
64 | return prev_sample
65 |
66 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
67 | timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
68 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
69 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
70 | beta_prod_t = 1 - alpha_prod_t
71 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
72 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
73 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
74 | return next_sample
75 |
76 | def get_noise_pred_single(self, latents, t, context):
77 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
78 | return noise_pred
79 |
80 | def get_noise_pred(self, latents, t, is_forward=True, context=None):
81 | latents_input = torch.cat([latents] * 2)
82 | if context is None:
83 | context = self.context
84 | guidance_scale = 1 if is_forward else self.model_config["guidance_scale"]
85 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
86 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
87 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
88 | if is_forward:
89 | latents = self.next_step(noise_pred, t, latents)
90 | else:
91 | latents = self.prev_step(noise_pred, t, latents)
92 | return latents
93 |
94 | @torch.no_grad()
95 | def latent2image(self, latents, return_type='np'):
96 | latents = 1 / 0.18215 * latents.detach()
97 | image = self.model.vae.decode(latents)['sample']
98 | if return_type == 'np':
99 | image = (image / 2 + 0.5).clamp(0, 1)
100 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
101 | image = (image * 255).astype(np.uint8)
102 | return image
103 |
104 | @torch.no_grad()
105 | def image2latent(self, image):
106 | with torch.no_grad():
107 |             if isinstance(image, Image.Image):
108 | image = np.array(image)
109 | if type(image) is torch.Tensor and image.dim() == 4:
110 | latents = image
111 | else:
112 | image = torch.from_numpy(image).float() / 127.5 - 1
113 | image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
114 | latents = self.model.vae.encode(image)['latent_dist'].mean
115 | latents = latents * 0.18215
116 | return latents
117 |
118 | @torch.no_grad()
119 | def init_prompt(self, prompt: str):
120 | uncond_input = self.model.tokenizer(
121 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
122 | return_tensors="pt"
123 | )
124 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
125 | text_input = self.model.tokenizer(
126 | [prompt],
127 | padding="max_length",
128 | max_length=self.model.tokenizer.model_max_length,
129 | truncation=True,
130 | return_tensors="pt",
131 | )
132 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
133 | self.context = torch.cat([uncond_embeddings, text_embeddings])
134 | self.prompt = prompt
135 |
136 | @torch.no_grad()
137 | def ddim_loop(self, latent):
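   |         # DDIM-invert the latent over all diffusion steps using only the conditional embedding, keeping every intermediate latent.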
138 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
139 | all_latent = [latent]
140 | latent = latent.clone().detach()
141 | for i in tqdm(range(self.model_config["num_diffusion_steps"])):
142 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
143 | noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
144 | latent = self.next_step(noise_pred, t, latent)
145 | all_latent.append(latent)
146 | return all_latent
147 |
148 | @property
149 | def scheduler(self):
150 | return self.model.scheduler
151 |
152 | @torch.no_grad()
153 | def ddim_inversion(self, image):
154 | latent = self.image2latent(image)
155 | image_rec = self.latent2image(latent)
156 | ddim_latents = self.ddim_loop(latent)
157 | return image_rec, ddim_latents
158 |
159 | def null_optimization(self, latents, num_inner_steps, epsilon):
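   |         # For each timestep, optimize the unconditional ("null-text") embedding so that a guided denoising step reproduces the corresponding latent from the DDIM inversion trajectory.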
160 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
161 | uncond_embeddings_list = []
162 | latent_cur = latents[-1]
163 | with tqdm(total=num_inner_steps * (self.model_config["num_diffusion_steps"])) as bar:
164 | for i in range(self.model_config["num_diffusion_steps"]):
165 | uncond_embeddings = uncond_embeddings.clone().detach()
166 | uncond_embeddings.requires_grad = True
167 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
168 | latent_prev = latents[len(latents) - i - 2]
169 | t = self.model.scheduler.timesteps[i]
170 | with torch.no_grad():
171 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
172 | for j in range(num_inner_steps):
173 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
174 | noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
175 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
176 | loss = nnf.mse_loss(latents_prev_rec, latent_prev)
177 | optimizer.zero_grad()
178 | loss.backward()
179 | optimizer.step()
180 | loss_item = loss.item()
181 | bar.update()
182 | if loss_item < epsilon + i * 2e-5:
183 | break
184 | bar.update(num_inner_steps - j - 1)
185 | uncond_embeddings_list.append(uncond_embeddings[:1].detach())
186 | with torch.no_grad():
187 | context = torch.cat([uncond_embeddings, cond_embeddings])
188 | latent_cur = self.get_noise_pred(latent_cur, t, False, context)
189 | # bar.close()
190 | return uncond_embeddings_list
191 |
192 | def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
193 | self.init_prompt(prompt)
194 | image_gt = load_512(image_path, *offsets)
195 | if verbose:
196 | print("DDIM inversion...")
197 | image_rec, ddim_latents = self.ddim_inversion(image_gt)
198 | if verbose:
199 | print("Null-text optimization...")
200 | uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
201 | return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
202 |
--------------------------------------------------------------------------------
/src/prompt_mixing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scipy.signal import medfilt2d
3 |
4 | class PromptMixing:
5 | def __init__(self, args, object_of_interest_index, avg_cross_attn=None):
6 | self.object_of_interest_index = object_of_interest_index
7 | self.objects_to_preserve = [args.prompt.split().index(o) + 1 for o in args.objects_to_preserve]
8 | self.obj_pixels_injection_threshold = args.obj_pixels_injection_threshold
9 |
10 | self.start_other_prompt_range = args.start_prompt_range
11 | self.end_other_prompt_range = args.end_prompt_range
12 |
13 | self.start_cross_attn_replace_range = args.num_diffusion_steps
14 | self.end_cross_attn_replace_range = args.num_diffusion_steps
15 |
16 | self.start_self_attn_replace_range = 0
17 | self.end_self_attn_replace_range = args.end_preserved_obj_self_attn_masking
18 | self.remove_obj_from_self_mask = args.remove_obj_from_self_mask
19 | self.avg_cross_attn = avg_cross_attn
20 |
21 | self.low_resource = args.low_resource
22 |
23 | def get_context_for_v(self, t, context, other_context):
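   |         # Within the prompt-mixing step range, compute the attention values (V) from the other prompt's context; in the batched classifier-free-guidance case only the conditional half is swapped.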
24 | if other_context is not None and \
25 | self.start_other_prompt_range <= t < self.end_other_prompt_range:
26 | if self.low_resource:
27 | return other_context
28 | else:
29 | v_context = context.clone()
30 |                 # first half of context is for the unconditioned image
31 | v_context[v_context.shape[0]//2:] = other_context
32 | return v_context
33 | else:
34 | return context
35 |
36 | def get_cross_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
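   |         # Within the configured step range, smooth the object-of-interest token's cross-attention with a 3x3 median filter, blended 0.2/0.8 with the original map.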
37 | if self.start_cross_attn_replace_range <= t < self.end_cross_attn_replace_range:
38 | if self.low_resource:
39 | attn[:,:,self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
40 | 0.8 * attn[:, :, self.object_of_interest_index]
41 | else:
42 |             # first half of attn maps is for the unconditioned image
43 | min_h = attn.shape[0] // 2
44 | attn[min_h:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[min_h:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
45 | 0.8 * attn[min_h:, :, self.object_of_interest_index]
46 | return attn
47 |
48 | def get_self_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
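   |         # For low-resolution self-attention layers within the configured step range, use the averaged cross-attention stored for this layer to decide which patches' self-attention to inject from the original-prompt image.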
49 | if attn.shape[1] <= 32 ** 2 and \
50 | self.avg_cross_attn is not None and \
51 | self.start_self_attn_replace_range <= t < self.end_self_attn_replace_range:
52 |
53 | key = f"{place_in_unet}_cross"
54 | attn_index = getattr(diffusion_model_wrapper, f'{key}_index')
55 | cr = self.avg_cross_attn[key][attn_index]
56 | setattr(diffusion_model_wrapper, f'{key}_index', attn_index+1)
57 |
58 | if self.low_resource:
59 | attn = self.mask_self_attn_patches(attn, cr, batch_size)
60 | else:
61 |             # first half of attn maps is for the unconditioned image
62 | attn[attn.shape[0]//2:] = self.mask_self_attn_patches(attn[attn.shape[0]//2:], cr, batch_size//2)
63 |
64 | return attn
65 |
66 | def mask_self_attn_patches(self, self_attn, cross_attn, batch_size):
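   |         # Mask patches whose cross-attention to the preserved-object tokens exceeds the threshold, and copy the self-attention of those rows/columns from the first batch item (the original prompt); optionally exclude patches of the object of interest.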
67 | h = self_attn.shape[0] // batch_size
68 | tokens = self.objects_to_preserve
69 | obj_token = self.object_of_interest_index
70 |
71 | normalized_cross_attn = cross_attn - cross_attn.min()
72 | normalized_cross_attn /= normalized_cross_attn.max()
73 |
74 | mask = torch.zeros_like(self_attn[0])
75 | for tk in tokens:
76 | mask_tk_in = torch.unique((normalized_cross_attn[:,:,tk] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
77 | mask[mask_tk_in, :] = 1
78 | mask[:, mask_tk_in] = 1
79 |
80 | if self.remove_obj_from_self_mask:
81 | obj_patches = torch.unique((normalized_cross_attn[:,:,obj_token] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
82 | mask[obj_patches, :] = 0
83 | mask[:, obj_patches] = 0
84 |
85 | self_attn[h:] = self_attn[h:] * (1 - mask) + self_attn[:h].repeat(batch_size - 1, 1, 1) * mask
86 | return self_attn
--------------------------------------------------------------------------------
/src/prompt_to_prompt_controllers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import abc
4 | from typing import Optional, Union, Tuple, Dict
5 | import src.seq_aligner as seq_aligner
6 |
7 |
8 | class AttentionControl(abc.ABC):
9 |
10 | def step_callback(self, x_t):
11 | return x_t
12 |
13 | def between_steps(self):
14 | return
15 |
16 | @property
17 | def num_uncond_att_layers(self):
18 | return self.num_att_layers if self.low_resource else 0
19 |
20 | @abc.abstractmethod
21 | def forward(self, attn, is_cross: bool, place_in_unet: str):
22 | raise NotImplementedError
23 |
24 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
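   |         # When unconditional and conditional branches are batched together, only the conditional half of the attention maps is passed to the controller.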
25 | if self.cur_att_layer >= self.num_uncond_att_layers:
26 | if self.low_resource:
27 | attn = self.forward(attn, is_cross, place_in_unet)
28 | else:
29 | h = attn.shape[0]
30 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
31 | self.cur_att_layer += 1
32 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
33 | self.cur_att_layer = 0
34 | self.cur_step += 1
35 | self.between_steps()
36 | return attn
37 |
38 | def reset(self):
39 | self.cur_step = 0
40 | self.cur_att_layer = 0
41 |
42 | def __init__(self, low_resource):
43 | self.cur_step = 0
44 | self.num_att_layers = -1
45 | self.cur_att_layer = 0
46 | self.low_resource = low_resource
47 |
48 |
49 | class EmptyControl(AttentionControl):
50 |
51 | def forward(self, attn, is_cross: bool, place_in_unet: str):
52 | return attn
53 |
54 |
55 | class DummyController:
56 | def __call__(self, *args):
57 | return args[0]
58 |
59 | def __init__(self):
60 | self.num_att_layers = 0
61 |
62 |
63 | class AttentionStore(AttentionControl):
64 |
65 | @staticmethod
66 | def get_empty_store():
67 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
68 | "down_self": [], "mid_self": [], "up_self": []}
69 |
70 | def forward(self, attn, is_cross: bool, place_in_unet: str):
71 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
72 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
73 | self.step_store[key].append(attn)
74 | return attn
75 |
76 | def between_steps(self):
77 | if len(self.attention_store) == 0:
78 | self.attention_store = self.step_store
79 | else:
80 | for key in self.attention_store:
81 | for i in range(len(self.attention_store[key])):
82 | self.attention_store[key][i] += self.step_store[key][i]
83 | self.step_store = self.get_empty_store()
84 |
85 | def get_average_attention(self):
86 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
87 | self.attention_store}
88 | return average_attention
89 |
90 | def reset(self):
91 | super(AttentionStore, self).reset()
92 | self.step_store = self.get_empty_store()
93 | self.attention_store = {}
94 |
95 | def __init__(self, low_resource):
96 | super(AttentionStore, self).__init__(low_resource)
97 | self.step_store = self.get_empty_store()
98 | self.attention_store = {}
99 |
100 |
101 | class AttentionControlEdit(AttentionStore, abc.ABC):
102 |
103 | def step_callback(self, x_t):
104 | return x_t
105 |
106 | def replace_self_attention(self, attn_base, att_replace):
107 | if att_replace.shape[2] <= 16 ** 2:
108 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
109 | else:
110 | return att_replace
111 |
112 | @abc.abstractmethod
113 | def replace_cross_attention(self, attn_base, att_replace):
114 | raise NotImplementedError
115 |
116 | def forward(self, attn, is_cross: bool, place_in_unet: str):
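   |         # Cross-attention is blended with the base prompt's (re-mapped) maps according to the per-word, per-step alpha schedule; self-attention is copied from the base prompt within the configured step range.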
117 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
118 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
119 |             h = attn.shape[0] // self.batch_size
120 |             attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
121 |             attn_base, attn_replace = attn[0], attn[1:]
122 |             if is_cross:
123 |                 alpha_words = self.cross_replace_alpha[self.cur_step]
124 |                 attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
125 |                         1 - alpha_words) * attn_replace
126 |                 attn[1:] = attn_replace_new
127 |             else:
128 |                 attn[1:] = self.replace_self_attention(attn_base, attn_replace)
129 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
130 | return attn
131 |
132 | def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int,
133 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
134 | self_replace_steps: Union[float, Tuple[float, float]]):
135 | super(AttentionControlEdit, self).__init__(low_resource)
136 | self.batch_size = len(prompts)
137 | self.tokenizer = tokenizer
138 | self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
139 | self.tokenizer).to(device)
140 | if type(self_replace_steps) is float:
141 | self_replace_steps = 0, self_replace_steps
142 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
143 |
144 |
145 | class AttentionReplace(AttentionControlEdit):
146 |
147 | def replace_cross_attention(self, attn_base, att_replace):
148 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper.to(attn_base.dtype))
149 |
150 | def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int, cross_replace_steps: float, self_replace_steps: float):
151 | super(AttentionReplace, self).__init__(prompts, tokenizer, device, low_resource, num_steps, cross_replace_steps, self_replace_steps)
152 | self.mapper = seq_aligner.get_replacement_mapper(prompts, self.tokenizer).to(device)
153 |
154 |
155 | def get_word_inds(text: str, word_place: int, tokenizer):
156 | split_text = text.split(" ")
157 | if type(word_place) is str:
158 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
159 | elif type(word_place) is int:
160 | word_place = [word_place]
161 | out = []
162 | if len(word_place) > 0:
163 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
164 | cur_len, ptr = 0, 0
165 |
166 | for i in range(len(words_encode)):
167 | cur_len += len(words_encode[i])
168 | if ptr in word_place:
169 | out.append(i + 1)
170 | if cur_len >= len(split_text[ptr]):
171 | ptr += 1
172 | cur_len = 0
173 | return np.array(out)
174 |
175 |
176 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None):
177 | if type(bounds) is float:
178 | bounds = 0, bounds
179 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
180 | if word_inds is None:
181 | word_inds = torch.arange(alpha.shape[2])
182 | alpha[: start, prompt_ind, word_inds] = 0
183 | alpha[start: end, prompt_ind, word_inds] = 1
184 | alpha[end:, prompt_ind, word_inds] = 0
185 | return alpha
186 |
187 |
188 | def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
189 | tokenizer, max_num_words=77):
190 | if type(cross_replace_steps) is not dict:
191 | cross_replace_steps = {"default_": cross_replace_steps}
192 | if "default_" not in cross_replace_steps:
193 | cross_replace_steps["default_"] = (0., 1.)
194 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
195 | for i in range(len(prompts) - 1):
196 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
197 | i)
198 | for key, item in cross_replace_steps.items():
199 | if key != "default_":
200 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
201 | for i, ind in enumerate(inds):
202 | if len(ind) > 0:
203 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
204 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words
205 | return alpha_time_words
--------------------------------------------------------------------------------
/src/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 |
7 | def get_topk_similar_words(model, prompt, base_word, vocab, k=30):
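   |     # Rank the vocabulary by cosine similarity between the pooled text-encoder embedding of the prompt with the base word and the embeddings of the prompt with each candidate word substituted.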
8 | text_input = model.tokenizer(
9 | [prompt.format(word=base_word)],
10 | padding="max_length",
11 | max_length=model.tokenizer.model_max_length,
12 | truncation=True,
13 | return_tensors="pt",
14 | )
15 | with torch.no_grad():
16 | encoder_output = model.text_encoder(text_input.input_ids.to(model.device))
17 | full_prompt_embedding = encoder_output.pooler_output
18 | full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True)
19 |
20 | prompts = [prompt.format(word=word) for word in vocab]
21 | batch_size = 1000
22 | all_prompts_embeddings = []
23 | for i in tqdm(range(0, len(prompts), batch_size)):
24 | curr_prompts = prompts[i:i + batch_size]
25 | with torch.no_grad():
26 | text_input = model.tokenizer(
27 | curr_prompts,
28 | padding="max_length",
29 | max_length=model.tokenizer.model_max_length,
30 | truncation=True,
31 | return_tensors="pt",
32 | )
33 | curr_embeddings = model.text_encoder(text_input.input_ids.to(model.device)).pooler_output
34 | all_prompts_embeddings.append(curr_embeddings)
35 |
36 | all_prompts_embeddings = torch.cat(all_prompts_embeddings)
37 | all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True)
38 | prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1))
39 |     sorted_word_indices = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort())
40 |
41 | print(f"prompt: {prompt}")
42 | print(f"initial word: {base_word}")
43 | print(f"TOP {k} SIMILAR WORDS:")
44 |     similar_words = [vocab[index] for index in sorted_word_indices[:k]]
45 | print(similar_words)
46 | return similar_words
47 |
48 | def get_proxy_words(args, ldm_stable):
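   |     # If no proxy words are supplied, choose them automatically: shortlist vocabulary words similar to the object in a generic "a photo of a {word}" context, then re-rank the shortlist within the user's prompt.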
49 | if len(args.proxy_words) > 0:
50 | return [args.object_of_interest] + args.proxy_words
51 | vocab = list(json.load(open("vocab.json")).keys())
52 | vocab = [word for word in vocab if word.isalpha() and len(word) > 1]
53 | filtered_vocab = get_topk_similar_words(ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50)
54 | proxy_words = get_topk_similar_words(ldm_stable, args.prompt, args.object_of_interest, filtered_vocab, k=args.number_of_variations)
55 | if proxy_words[0] != args.object_of_interest:
56 | proxy_words = [args.object_of_interest] + proxy_words
57 |
58 | return proxy_words
59 |
60 | def get_proxy_prompts(args, ldm_stable):
61 | proxy_words = get_proxy_words(args, ldm_stable)
62 | prompts = [args.prompt.format(word=args.object_of_interest)]
63 | proxy_prompts = [{"word": word, "prompt": args.prompt.format(word=word)} for word in proxy_words]
64 | return proxy_words, prompts, proxy_prompts
--------------------------------------------------------------------------------
/src/seq_aligner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 |
17 |
18 | class ScoreParams:
19 |
20 | def __init__(self, gap, match, mismatch):
21 | self.gap = gap
22 | self.match = match
23 | self.mismatch = mismatch
24 |
25 | def mis_match_char(self, x, y):
26 | if x != y:
27 | return self.mismatch
28 | else:
29 | return self.match
30 |
31 |
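   | # Pure-Python scoring-matrix initialization; shadowed by the vectorized NumPy implementation of the same name below.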
32 | def get_matrix(size_x, size_y, gap):
33 |     matrix = []
34 |     for i in range(size_x + 1):
35 |         sub_matrix = []
36 |         for j in range(size_y + 1):
37 |             sub_matrix.append(0)
38 |         matrix.append(sub_matrix)
39 |     for j in range(1, size_y + 1):
40 |         matrix[0][j] = j*gap
41 |     for i in range(1, size_x + 1):
42 |         matrix[i][0] = i*gap
43 |     return matrix
44 |
45 |
46 | def get_matrix(size_x, size_y, gap):
47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
50 | return matrix
51 |
52 |
53 | def get_traceback_matrix(size_x, size_y):
54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
55 | matrix[0, 1:] = 1
56 | matrix[1:, 0] = 2
57 | matrix[0, 0] = 4
58 | return matrix
59 |
60 |
61 | def global_align(x, y, score):
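   |     # Needleman-Wunsch global alignment; traceback codes: 1 = gap in x (left), 2 = gap in y (up), 3 = match/mismatch (diagonal), 4 = origin.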
62 | matrix = get_matrix(len(x), len(y), score.gap)
63 | trace_back = get_traceback_matrix(len(x), len(y))
64 | for i in range(1, len(x) + 1):
65 | for j in range(1, len(y) + 1):
66 | left = matrix[i, j - 1] + score.gap
67 | up = matrix[i - 1, j] + score.gap
68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
69 | matrix[i, j] = max(left, up, diag)
70 | if matrix[i, j] == left:
71 | trace_back[i, j] = 1
72 | elif matrix[i, j] == up:
73 | trace_back[i, j] = 2
74 | else:
75 | trace_back[i, j] = 3
76 | return matrix, trace_back
77 |
78 |
79 | def get_aligned_sequences(x, y, trace_back):
80 | x_seq = []
81 | y_seq = []
82 | i = len(x)
83 | j = len(y)
84 | mapper_y_to_x = []
85 | while i > 0 or j > 0:
86 | if trace_back[i, j] == 3:
87 | x_seq.append(x[i-1])
88 | y_seq.append(y[j-1])
89 | i = i-1
90 | j = j-1
91 | mapper_y_to_x.append((j, i))
92 | elif trace_back[i][j] == 1:
93 | x_seq.append('-')
94 | y_seq.append(y[j-1])
95 | j = j-1
96 | mapper_y_to_x.append((j, -1))
97 | elif trace_back[i][j] == 2:
98 | x_seq.append(x[i-1])
99 | y_seq.append('-')
100 | i = i-1
101 | elif trace_back[i][j] == 4:
102 | break
103 | mapper_y_to_x.reverse()
104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
105 |
106 |
107 | def get_mapper(x: str, y: str, tokenizer, max_len=77):
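   |     # Align the token sequences of prompts x and y; mapper[k] holds the index in x of the token aligned with token k of y, and alpha is 0 where the y token has no counterpart in x.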
108 | x_seq = tokenizer.encode(x)
109 | y_seq = tokenizer.encode(y)
110 | score = ScoreParams(0, 1, -1)
111 | matrix, trace_back = global_align(x_seq, y_seq, score)
112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
113 | alphas = torch.ones(max_len)
114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
115 | mapper = torch.zeros(max_len, dtype=torch.int64)
116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
118 | return mapper, alphas
119 |
120 |
121 | def get_refinement_mapper(prompts, tokenizer, max_len=77):
122 | x_seq = prompts[0]
123 | mappers, alphas = [], []
124 | for i in range(1, len(prompts)):
125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
126 | mappers.append(mapper)
127 | alphas.append(alpha)
128 | return torch.stack(mappers), torch.stack(alphas)
129 |
130 |
131 | def get_word_inds(text: str, word_place: int, tokenizer):
132 | split_text = text.split(" ")
133 | if type(word_place) is str:
134 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
135 | elif type(word_place) is int:
136 | word_place = [word_place]
137 | out = []
138 | if len(word_place) > 0:
139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
140 | cur_len, ptr = 0, 0
141 |
142 | for i in range(len(words_encode)):
143 | cur_len += len(words_encode[i])
144 | if ptr in word_place:
145 | out.append(i + 1)
146 | if cur_len >= len(split_text[ptr]):
147 | ptr += 1
148 | cur_len = 0
149 | return np.array(out)
150 |
151 |
152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
153 | words_x = x.split(' ')
154 | words_y = y.split(' ')
155 | if len(words_x) != len(words_y):
156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
161 | mapper = np.zeros((max_len, max_len))
162 | i = j = 0
163 | cur_inds = 0
164 | while i < max_len and j < max_len:
165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
167 | if len(inds_source_) == len(inds_target_):
168 | mapper[inds_source_, inds_target_] = 1
169 | else:
170 | ratio = 1 / len(inds_target_)
171 | for i_t in inds_target_:
172 | mapper[inds_source_, i_t] = ratio
173 | cur_inds += 1
174 | i += len(inds_source_)
175 | j += len(inds_target_)
176 | elif cur_inds < len(inds_source):
177 | mapper[i, j] = 1
178 | i += 1
179 | j += 1
180 | else:
181 | mapper[j, j] = 1
182 | i += 1
183 | j += 1
184 |
185 | return torch.from_numpy(mapper).float()
186 |
187 |
188 | def get_replacement_mapper(prompts, tokenizer, max_len=77):
189 | x_seq = prompts[0]
190 | mappers = []
191 | for i in range(1, len(prompts)):
192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
193 | mappers.append(mapper)
194 | return torch.stack(mappers)
195 |
196 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | h1 {
2 | text-align: center;
3 | }
4 |
--------------------------------------------------------------------------------