├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── eclipse_solar_eclipse.png
│   ├── example.png
│   └── results.png
├── main.py
├── requirements.txt
└── src
    ├── pipelines
    │   ├── __init__.py
    │   ├── pipeline_kandinsky_prior.py
    │   └── pipeline_unclip.py
    └── priors
        ├── __init__.py
        └── prior_transformer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 eclipse-t2i
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## [CVPR 2024] ECLIPSE: Revisiting the Text-to-Image Prior for Efficient Image Generation
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ---
13 |
14 | This repository contains the inference code for our paper, ECLIPSE.
15 | We show how to use the pre-trained ECLIPSE text-to-image prior in combination with diffusion image decoders such as Karlo and Kandinsky.
16 |
17 | - ECLIPSE presents a tiny prior-learning strategy that compresses previous prior models from 1 billion parameters down to 33 million parameters.
18 | - Additionally, the ECLIPSE prior is trained on a mere 5 million image-text (alt-text) pairs.
19 |
20 | > **_News:_** Check out our latest work, [λ-ECLIPSE](https://eclipse-t2i.github.io/Lambda-ECLIPSE/), which extends the T2I priors for efficient zero-shot, multi-subject-driven text-to-image generation.
21 |
22 |
23 | **Please follow the steps below to run inference locally.**
24 |
25 | ---
26 |
27 | **Qualitative Comparisons:**
28 | 
29 |
30 |
31 | **Quantitative Comparisons:**
32 | 
33 |
34 | ## TODOs:
35 |
36 | - [x] ~~Release ECLIPSE priors for Kandinsky v2.2 and Karlo-v1-alpha.~~
37 | - [x] ~~Release the demo.~~
38 | - [ ] Release ECLIPSE prior with Kandinsky v2.2 LCM decoder. (soon!)
39 | - [ ] Release ECLIPSE prior training code. (will be released in a separate repository)
40 |
41 |
42 | ## Setup
43 |
44 | ### Installation
45 | ```bash
46 | git clone git@github.com:eclipse-t2i/eclipse-inference.git
47 | cd eclipse-inference
48 | conda create -p ./venv python=3.9
49 | conda activate ./venv && pip install -r requirements.txt
50 | ```
51 |
52 | ### Demo
53 | ```bash
54 | conda activate ./venv
55 | gradio main.py
56 | ```
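Alternatively, because `main.py` ends with `demo.queue(max_size=20).launch()` under an `if __name__ == "__main__":` guard, running `python main.py` from the repository root (with the environment activated) should launch the same demo.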
57 |
58 | ## Run Inference
59 |
60 | This repository supports two pre-trained image decoders: [Karlo-v1-alpha](https://huggingface.co/kakaobrain/karlo-v1-alpha) and [Kandinsky-v2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).
61 |
62 | **Note:** The ECLIPSE prior is not a diffusion model, while the image decoders are (see the short sketch below).
63 |
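To make this concrete, here is a minimal, illustrative sketch of the single forward pass the prior performs, mirroring how `KandinskyPriorPipeline` in `src/pipelines/pipeline_kandinsky_prior.py` invokes it. The random stand-in tensors and the 1280-dimensional shapes are assumptions for illustration only, not a real CLIP text encoding:

```python
import torch
from src.priors.prior_transformer import PriorTransformer

# The ECLIPSE prior maps text features to an image embedding in a single forward
# pass -- there is no scheduler and no denoising loop at this stage.
prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior")

batch_size, embed_dim, seq_len = 1, 1280, 77          # assumed CLIP-ViT-bigG dimensions
prompt_embeds = torch.randn(batch_size, embed_dim)    # stand-in for the pooled text embedding
encoder_hidden_states = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for token states
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
hidden_states = torch.randn(batch_size, embed_dim)    # random input vector, as in the pipeline

image_embed = prior(
    hidden_states,
    proj_embedding=prompt_embeds,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
).predicted_image_embedding
print(image_embed.shape)  # one image embedding per prompt, computed in one step
```

In the full pipelines below, the real CLIP text encoder produces `prompt_embeds` and `encoder_hidden_states`, and only the image decoder runs an iterative diffusion process.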
64 |
65 | ### Kandinsky Inference
66 | ```python
67 | import torch
68 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer
69 | from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline
70 | from src.priors.prior_transformer import PriorTransformer
71 | from diffusers import DiffusionPipeline
72 | text_encoder = (
73 | CLIPTextModelWithProjection.from_pretrained(
74 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
75 | projection_dim=1280,
76 | torch_dtype=torch.float32,
77 | )
78 | )
79 |
80 | tokenizer = CLIPTokenizer.from_pretrained(
81 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
82 | )
83 |
84 | prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior")
85 | pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior",
86 | prior=prior,
87 | text_encoder=text_encoder,
88 | tokenizer=tokenizer,
89 | ).to("cuda")
90 |
91 | pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder").to("cuda")
92 |
93 | prompt = "black apples in the basket"
94 | image_embeds, negative_image_embeds = pipe_prior(prompt).to_tuple()
95 | images = pipe(
96 | num_inference_steps=50,
97 | image_embeds=image_embeds,
98 | negative_image_embeds=negative_image_embeds,
99 | ).images
100 |
101 | images[0]
102 | ```
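Optionally, both stages accept a `generator` for reproducibility (see the `__call__` signature in `src/pipelines/pipeline_kandinsky_prior.py`), and the decoder returns PIL images. Below is a small continuation of the snippet above; the seed value and filename are arbitrary:

```python
import torch

# Seed both the prior and the decoder, then save the first generated image.
generator = torch.Generator(device="cuda").manual_seed(42)

image_embeds, negative_image_embeds = pipe_prior(prompt, generator=generator).to_tuple()
images = pipe(
    num_inference_steps=50,
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    generator=generator,
).images
images[0].save("black_apples.png")
```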
103 |
104 |
105 | ### Karlo Inference
106 | ```python
107 | from src.pipelines.pipeline_unclip import UnCLIPPipeline
108 | from src.priors.prior_transformer import PriorTransformer
109 |
110 | prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_Karlo_Prior")
111 | pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", prior=prior).to("cuda")
112 |
113 | prompt="black apples in the basket"
114 | images = pipe(prompt, decoder_guidance_scale=7.5).images
115 |
116 | images[0]
117 | ```
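The customized `UnCLIPPipeline` in `src/pipelines/pipeline_unclip.py` also exposes the usual unCLIP controls (decoder and super-resolution step counts, decoder guidance scale, and a `generator` for seeding). Here is a short continuation of the snippet above with illustrative values, not tuned recommendations:

```python
import torch

# Seed the run and adjust the per-stage step counts of the unCLIP pipeline.
generator = torch.Generator(device="cuda").manual_seed(0)
images = pipe(
    prompt,
    generator=generator,
    decoder_num_inference_steps=25,    # decoder denoising steps
    super_res_num_inference_steps=7,   # super-resolution denoising steps
    decoder_guidance_scale=7.5,        # classifier-free guidance for the decoder
).images
images[0].save("black_apples_karlo.png")
```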
118 |
119 | ## Acknowledgement
120 |
121 | We would like to acknowledge the excellent open-source text-to-image models (Karlo and Kandinsky); without them, this work would not have been possible. We also thank Hugging Face for streamlining the T2I models.
122 |
--------------------------------------------------------------------------------
/assets/eclipse_solar_eclipse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/eclipse_solar_eclipse.png
--------------------------------------------------------------------------------
/assets/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/example.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/results.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from PIL import Image
3 |
4 | import torch
5 |
6 | from torchvision import transforms
7 | from transformers import (
8 | CLIPProcessor,
9 | CLIPModel,
10 | CLIPTokenizer,
11 | CLIPTextModelWithProjection,
12 | CLIPVisionModelWithProjection,
13 | CLIPFeatureExtractor,
14 | )
15 |
16 | import math
17 | from typing import List
18 | from PIL import Image, ImageChops
19 | import numpy as np
20 | import torch
21 |
22 | from diffusers import UnCLIPPipeline
23 |
24 | # from diffusers.utils.torch_utils import randn_tensor
25 |
26 | from transformers import CLIPTokenizer
27 |
28 | from src.priors.prior_transformer import (
29 | PriorTransformer,
30 | ) # original huggingface prior transformer without time conditioning
31 | from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline
32 |
33 | from diffusers import DiffusionPipeline
34 |
35 |
36 | __DEVICE__ = "cpu"
37 | if torch.cuda.is_available():
38 | __DEVICE__ = "cuda"
39 |
40 | class Ours:
41 | def __init__(self, device):
42 | text_encoder = (
43 | CLIPTextModelWithProjection.from_pretrained(
44 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
45 | projection_dim=1280,
46 | torch_dtype=torch.float16,
47 | )
48 | .eval()
49 | .requires_grad_(False)
50 | )
51 |
52 | tokenizer = CLIPTokenizer.from_pretrained(
53 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
54 | )
55 |
56 | prior = PriorTransformer.from_pretrained(
57 | "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
58 | torch_dtype=torch.float16,
59 | )
60 |
61 | self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
62 | "kandinsky-community/kandinsky-2-2-prior",
63 | prior=prior,
64 | text_encoder=text_encoder,
65 | tokenizer=tokenizer,
66 | torch_dtype=torch.float16,
67 | ).to(device)
68 |
69 | self.pipe = DiffusionPipeline.from_pretrained(
70 | "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
71 | ).to(device)
72 |
73 | def inference(self, text, negative_text, steps, guidance_scale):
74 | gen_images = []
75 | for i in range(1):
76 | image_emb, negative_image_emb = self.pipe_prior(
77 | text, negative_prompt=negative_text
78 | ).to_tuple()
79 | image = self.pipe(
80 | image_embeds=image_emb,
81 | negative_image_embeds=negative_image_emb,
82 | num_inference_steps=steps,
83 | guidance_scale=guidance_scale,
84 | ).images
85 | gen_images.append(image[0])
86 | return gen_images
87 |
88 |
89 | selected_model = Ours(device=__DEVICE__)
90 |
91 |
92 | def get_images(text, negative_text, steps, guidance_scale):
93 | images = selected_model.inference(text, negative_text, steps, guidance_scale)
94 | new_images = []
95 | for img in images:
96 | new_images.append(img)
97 | return new_images[0]
98 |
99 |
100 | with gr.Blocks() as demo:
101 | gr.Markdown(
102 | """ECLIPSE: Revisiting the Text-to-Image Prior for Effecient Image Generation
103 |
104 | """
105 | )
106 |
107 | with gr.Group():
108 | with gr.Row():
109 | with gr.Column():
110 | text = gr.Textbox(
111 | label="Enter your prompt",
112 | show_label=False,
113 | max_lines=1,
114 | placeholder="Enter your prompt",
115 | elem_id="prompt-text-input",
116 | ).style(
117 | border=(True, False, True, True),
118 | rounded=(True, False, False, True),
119 | container=False,
120 | )
121 |
122 | with gr.Row():
123 | with gr.Column():
124 | negative_text = gr.Textbox(
125 | label="Enter your negative prompt",
126 | show_label=False,
127 | max_lines=1,
128 | placeholder="Enter your negative prompt",
129 | elem_id="prompt-text-input",
130 | ).style(
131 | border=(True, False, True, True),
132 | rounded=(True, False, False, True),
133 | container=False,
134 | )
135 |
136 | with gr.Row():
137 | steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=1)
138 | guidance_scale = gr.Slider(
139 | label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
140 | )
141 |
142 | with gr.Row():
143 | btn = gr.Button(value="Generate Image", full_width=False)
144 |
145 | gallery = gr.Image(
146 | height=512, width=512, label="Generated images", show_label=True, elem_id="gallery"
147 | ).style(preview=False, columns=1)
148 |
149 | btn.click(
150 | get_images,
151 | inputs=[
152 | text,
153 | negative_text,
154 | steps,
155 | guidance_scale,
156 | ],
157 | outputs=gallery,
158 | )
159 | text.submit(
160 | get_images,
161 | inputs=[
162 | text,
163 | negative_text,
164 | steps,
165 | guidance_scale,
166 | ],
167 | outputs=gallery,
168 | )
169 | negative_text.submit(
170 | get_images,
171 | inputs=[
172 | text,
173 | negative_text,
174 | steps,
175 | guidance_scale,
176 | ],
177 | outputs=gallery,
178 | )
179 |
180 | with gr.Accordion(label="Ethics & Privacy", open=False):
181 | gr.HTML(
182 | """
183 | Privacy:
184 | We do not collect any images or key data. This demo is designed with the sole purpose of fun and reducing misuse of AI.
185 | Biases and content acknowledgment:
186 | This model will have the same biases as the pre-trained CLIP model.
187 | """
188 | )
189 |
190 | if __name__ == "__main__":
191 | demo.queue(max_size=20).launch()
192 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.24.0
2 | datasets==2.14.6
3 | diffusers==0.20.2
4 | numpy==1.26.1
5 | packaging==23.2
6 | pandas_stubs==1.2.0.57
7 | Pillow==10.1.0
8 | torch==2.0.0
9 | torchvision==0.15.1
10 | tqdm==4.66.1
11 | transformers==4.34.1
12 | jmespath
13 | opencv-python
14 | PyWavelets
15 | gradio==3.47.1
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/src/pipelines/__init__.py
--------------------------------------------------------------------------------
/src/pipelines/pipeline_kandinsky_prior.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Union
3 |
4 | import numpy as np
5 | import PIL
6 | import torch
7 | from transformers import (
8 | CLIPImageProcessor,
9 | CLIPTextModelWithProjection,
10 | CLIPTokenizer,
11 | CLIPVisionModelWithProjection,
12 | )
13 |
14 | from diffusers.models import PriorTransformer
15 | from diffusers.schedulers import UnCLIPScheduler
16 | from diffusers.utils import (
17 | BaseOutput,
18 | is_accelerate_available,
19 | is_accelerate_version,
20 | logging,
21 | randn_tensor,
22 | replace_example_docstring,
23 | )
24 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25 |
26 |
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28 |
29 | EXAMPLE_DOC_STRING = """
30 | Examples:
31 | ```py
32 | >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
33 | >>> import torch
34 |
35 | >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
36 | >>> pipe_prior.to("cuda")
37 |
38 | >>> prompt = "red cat, 4k photo"
39 | >>> out = pipe_prior(prompt)
40 | >>> image_emb = out.image_embeds
41 | >>> negative_image_emb = out.negative_image_embeds
42 |
43 | >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
44 | >>> pipe.to("cuda")
45 |
46 | >>> image = pipe(
47 | ... prompt,
48 | ... image_embeds=image_emb,
49 | ... negative_image_embeds=negative_image_emb,
50 | ... height=768,
51 | ... width=768,
52 | ... num_inference_steps=100,
53 | ... ).images
54 |
55 | >>> image[0].save("cat.png")
56 | ```
57 | """
58 |
59 | EXAMPLE_INTERPOLATE_DOC_STRING = """
60 | Examples:
61 | ```py
62 | >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
63 | >>> from diffusers.utils import load_image
64 | >>> import PIL
65 |
66 | >>> import torch
67 | >>> from torchvision import transforms
68 |
69 | >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
70 | ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
71 | ... )
72 | >>> pipe_prior.to("cuda")
73 |
74 | >>> img1 = load_image(
75 | ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
76 | ... "/kandinsky/cat.png"
77 | ... )
78 |
79 | >>> img2 = load_image(
80 | ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
81 | ... "/kandinsky/starry_night.jpeg"
82 | ... )
83 |
84 | >>> images_texts = ["a cat", img1, img2]
85 | >>> weights = [0.3, 0.3, 0.4]
86 | >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
87 |
88 | >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
89 | >>> pipe.to("cuda")
90 |
91 | >>> image = pipe(
92 | ... "",
93 | ... image_embeds=image_emb,
94 | ... negative_image_embeds=zero_image_emb,
95 | ... height=768,
96 | ... width=768,
97 | ... num_inference_steps=150,
98 | ... ).images[0]
99 |
100 | >>> image.save("starry_cat.png")
101 | ```
102 | """
103 |
104 |
105 | @dataclass
106 | class KandinskyPriorPipelineOutput(BaseOutput):
107 | """
108 | Output class for KandinskyPriorPipeline.
109 |
110 | Args:
111 | image_embeds (`torch.FloatTensor`)
112 | clip image embeddings for text prompt
113 | negative_image_embeds (`torch.FloatTensor` or `np.ndarray`)
114 | clip image embeddings for unconditional tokens
115 | """
116 |
117 | image_embeds: Union[torch.FloatTensor, np.ndarray]
118 | negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
119 |
120 |
121 | class KandinskyPriorPipeline(DiffusionPipeline):
122 | """
123 | Pipeline for generating image prior for Kandinsky
124 |
125 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
126 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
127 |
128 | Args:
129 | prior ([`PriorTransformer`]):
130 | The canonical unCLIP prior to approximate the image embedding from the text embedding.
131 | image_encoder ([`CLIPVisionModelWithProjection`]):
132 | Frozen image-encoder.
133 | text_encoder ([`CLIPTextModelWithProjection`]):
134 | Frozen text-encoder.
135 | tokenizer (`CLIPTokenizer`):
136 | Tokenizer of class
137 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
138 | scheduler ([`UnCLIPScheduler`]):
139 | A scheduler to be used in combination with `prior` to generate image embedding.
140 | """
141 |
142 | _exclude_from_cpu_offload = ["prior"]
143 |
144 | def __init__(
145 | self,
146 | prior: PriorTransformer,
147 | image_encoder: CLIPVisionModelWithProjection,
148 | text_encoder: CLIPTextModelWithProjection,
149 | tokenizer: CLIPTokenizer,
150 | scheduler: UnCLIPScheduler,
151 | image_processor: CLIPImageProcessor,
152 | ):
153 | super().__init__()
154 |
155 | self.register_modules(
156 | prior=prior,
157 | text_encoder=text_encoder,
158 | tokenizer=tokenizer,
159 | scheduler=scheduler,
160 | image_encoder=image_encoder,
161 | image_processor=image_processor,
162 | )
163 |
164 | @torch.no_grad()
165 | @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
166 | def interpolate(
167 | self,
168 | images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
169 | weights: List[float],
170 | num_images_per_prompt: int = 1,
171 | num_inference_steps: int = 25,
172 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
173 | latents: Optional[torch.FloatTensor] = None,
174 | negative_prior_prompt: Optional[str] = None,
175 | negative_prompt: str = "",
176 | guidance_scale: float = 4.0,
177 | device=None,
178 | ):
179 | """
180 | Function invoked when using the prior pipeline for interpolation.
181 |
182 | Args:
183 | images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
184 | list of prompts and images to guide the image generation.
185 | weights: (`List[float]`):
186 | list of weights for each condition in `images_and_prompts`
187 | num_images_per_prompt (`int`, *optional*, defaults to 1):
188 | The number of images to generate per prompt.
189 | num_inference_steps (`int`, *optional*, defaults to 25):
190 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
191 | expense of slower inference.
192 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
193 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
194 | to make generation deterministic.
195 | latents (`torch.FloatTensor`, *optional*):
196 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
197 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
198 | tensor will be generated by sampling using the supplied random `generator`.
199 | negative_prior_prompt (`str`, *optional*):
200 | The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
201 | `guidance_scale` is less than `1`).
202 | negative_prompt (`str` or `List[str]`, *optional*):
203 | The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
204 | `guidance_scale` is less than `1`).
205 | guidance_scale (`float`, *optional*, defaults to 4.0):
206 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
207 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
208 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
209 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
210 | usually at the expense of lower image quality.
211 |
212 | Examples:
213 |
214 | Returns:
215 | [`KandinskyPriorPipelineOutput`] or `tuple`
216 | """
217 |
218 | device = device or self.device
219 |
220 | if len(images_and_prompts) != len(weights):
221 | raise ValueError(
222 | f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
223 | )
224 |
225 | image_embeddings = []
226 | for cond, weight in zip(images_and_prompts, weights):
227 | if isinstance(cond, str):
228 | image_emb = self(
229 | cond,
230 | num_inference_steps=num_inference_steps,
231 | num_images_per_prompt=num_images_per_prompt,
232 | generator=generator,
233 | latents=latents,
234 | negative_prompt=negative_prior_prompt,
235 | guidance_scale=guidance_scale,
236 | ).image_embeds
237 |
238 | elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
239 | if isinstance(cond, PIL.Image.Image):
240 | cond = (
241 | self.image_processor(cond, return_tensors="pt")
242 | .pixel_values[0]
243 | .unsqueeze(0)
244 | .to(dtype=self.image_encoder.dtype, device=device)
245 | )
246 |
247 | image_emb = self.image_encoder(cond)["image_embeds"]
248 |
249 | else:
250 | raise ValueError(
251 | f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
252 | )
253 |
254 | image_embeddings.append(image_emb * weight)
255 |
256 | image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
257 |
258 | out_zero = self(
259 | negative_prompt,
260 | num_inference_steps=num_inference_steps,
261 | num_images_per_prompt=num_images_per_prompt,
262 | generator=generator,
263 | latents=latents,
264 | negative_prompt=negative_prior_prompt,
265 | guidance_scale=guidance_scale,
266 | )
267 | zero_image_emb = (
268 | out_zero.negative_image_embeds
269 | if negative_prompt == ""
270 | else out_zero.image_embeds
271 | )
272 |
273 | return KandinskyPriorPipelineOutput(
274 | image_embeds=image_emb, negative_image_embeds=zero_image_emb
275 | )
276 |
277 | # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
278 | def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
279 | if latents is None:
280 | latents = randn_tensor(
281 | shape, generator=generator, device=device, dtype=dtype
282 | )
283 | else:
284 | if latents.shape != shape:
285 | raise ValueError(
286 | f"Unexpected latents shape, got {latents.shape}, expected {shape}"
287 | )
288 | latents = latents.to(device)
289 |
290 | latents = latents * scheduler.init_noise_sigma
291 | return latents
292 |
293 | def get_zero_embed(self, batch_size=1, device=None):
294 | device = device or self.device
295 | zero_img = torch.zeros(
296 | 1,
297 | 3,
298 | self.image_encoder.config.image_size,
299 | self.image_encoder.config.image_size,
300 | ).to(device=device, dtype=self.image_encoder.dtype)
301 | zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
302 | zero_image_emb = zero_image_emb.repeat(batch_size, 1)
303 | return zero_image_emb
304 |
305 | def _encode_prompt(
306 | self,
307 | prompt,
308 | device,
309 | num_images_per_prompt,
310 | do_classifier_free_guidance,
311 | negative_prompt=None,
312 | ):
313 | batch_size = len(prompt) if isinstance(prompt, list) else 1
314 | # get prompt text embeddings
315 | text_inputs = self.tokenizer(
316 | prompt,
317 | padding="max_length",
318 | max_length=self.tokenizer.model_max_length,
319 | truncation=True,
320 | return_tensors="pt",
321 | )
322 | text_input_ids = text_inputs.input_ids
323 | text_mask = text_inputs.attention_mask.bool().to(device)
324 |
325 | untruncated_ids = self.tokenizer(
326 | prompt, padding="longest", return_tensors="pt"
327 | ).input_ids
328 |
329 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
330 | text_input_ids, untruncated_ids
331 | ):
332 | removed_text = self.tokenizer.batch_decode(
333 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
334 | )
335 | logger.warning(
336 | "The following part of your input was truncated because CLIP can only handle sequences up to"
337 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
338 | )
339 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
340 |
341 | text_encoder_output = self.text_encoder(text_input_ids.to(device))
342 |
343 | prompt_embeds = text_encoder_output.text_embeds
344 | text_encoder_hidden_states = text_encoder_output.last_hidden_state
345 |
346 | prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
347 | text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
348 | num_images_per_prompt, dim=0
349 | )
350 | text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
351 |
352 | if do_classifier_free_guidance:
353 | uncond_tokens: List[str]
354 | if negative_prompt is None:
355 | uncond_tokens = [""] * batch_size
356 | elif type(prompt) is not type(negative_prompt):
357 | raise TypeError(
358 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
359 | f" {type(prompt)}."
360 | )
361 | elif isinstance(negative_prompt, str):
362 | uncond_tokens = [negative_prompt]
363 | elif batch_size != len(negative_prompt):
364 | raise ValueError(
365 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
366 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
367 | " the batch size of `prompt`."
368 | )
369 | else:
370 | uncond_tokens = negative_prompt
371 |
372 | uncond_input = self.tokenizer(
373 | uncond_tokens,
374 | padding="max_length",
375 | max_length=self.tokenizer.model_max_length,
376 | truncation=True,
377 | return_tensors="pt",
378 | )
379 | uncond_text_mask = uncond_input.attention_mask.bool().to(device)
380 | negative_prompt_embeds_text_encoder_output = self.text_encoder(
381 | uncond_input.input_ids.to(device)
382 | )
383 |
384 | negative_prompt_embeds = (
385 | negative_prompt_embeds_text_encoder_output.text_embeds
386 | )
387 | uncond_text_encoder_hidden_states = (
388 | negative_prompt_embeds_text_encoder_output.last_hidden_state
389 | )
390 |
391 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392 |
393 | seq_len = negative_prompt_embeds.shape[1]
394 | negative_prompt_embeds = negative_prompt_embeds.repeat(
395 | 1, num_images_per_prompt
396 | )
397 | negative_prompt_embeds = negative_prompt_embeds.view(
398 | batch_size * num_images_per_prompt, seq_len
399 | )
400 |
401 | seq_len = uncond_text_encoder_hidden_states.shape[1]
402 | uncond_text_encoder_hidden_states = (
403 | uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
404 | )
405 | uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
406 | batch_size * num_images_per_prompt, seq_len, -1
407 | )
408 | uncond_text_mask = uncond_text_mask.repeat_interleave(
409 | num_images_per_prompt, dim=0
410 | )
411 |
412 | # done duplicates
413 |
414 | # For classifier free guidance, we need to do two forward passes.
415 | # Here we concatenate the unconditional and text embeddings into a single batch
416 | # to avoid doing two forward passes
417 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
418 | text_encoder_hidden_states = torch.cat(
419 | [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
420 | )
421 |
422 | text_mask = torch.cat([uncond_text_mask, text_mask])
423 |
424 | return prompt_embeds, text_encoder_hidden_states, text_mask
425 |
426 | def enable_model_cpu_offload(self, gpu_id=0):
427 | r"""
428 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
429 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
430 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
431 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
432 | """
433 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
434 | from accelerate import cpu_offload_with_hook
435 | else:
436 | raise ImportError(
437 | "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
438 | )
439 |
440 | device = torch.device(f"cuda:{gpu_id}")
441 |
442 | if self.device.type != "cpu":
443 | self.to("cpu", silence_dtype_warnings=True)
444 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
445 |
446 | hook = None
447 | for cpu_offloaded_model in [self.text_encoder, self.prior]:
448 | _, hook = cpu_offload_with_hook(
449 | cpu_offloaded_model, device, prev_module_hook=hook
450 | )
451 |
452 | # We'll offload the last model manually.
453 | self.prior_hook = hook
454 |
455 | _, hook = cpu_offload_with_hook(
456 | self.image_encoder, device, prev_module_hook=self.prior_hook
457 | )
458 |
459 | self.final_offload_hook = hook
460 |
461 | @torch.no_grad()
462 | @replace_example_docstring(EXAMPLE_DOC_STRING)
463 | def __call__(
464 | self,
465 | prompt: Union[str, List[str]],
466 | negative_prompt: Optional[Union[str, List[str]]] = None,
467 | num_images_per_prompt: int = 1,
468 | num_inference_steps: int = 25,
469 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
470 | latents: Optional[torch.FloatTensor] = None,
471 | guidance_scale: float = 4.0,
472 | output_type: Optional[str] = "pt",
473 | return_dict: bool = True,
474 | ):
475 | """
476 | Function invoked when calling the pipeline for generation.
477 |
478 | Args:
479 | prompt (`str` or `List[str]`):
480 | The prompt or prompts to guide the image generation.
481 | negative_prompt (`str` or `List[str]`, *optional*):
482 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
483 | if `guidance_scale` is less than `1`).
484 | num_images_per_prompt (`int`, *optional*, defaults to 1):
485 | The number of images to generate per prompt.
486 | num_inference_steps (`int`, *optional*, defaults to 25):
487 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
488 | expense of slower inference.
489 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
490 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
491 | to make generation deterministic.
492 | latents (`torch.FloatTensor`, *optional*):
493 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
494 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
495 | tensor will be generated by sampling using the supplied random `generator`.
496 | guidance_scale (`float`, *optional*, defaults to 4.0):
497 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
498 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
499 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
500 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
501 | usually at the expense of lower image quality.
502 | output_type (`str`, *optional*, defaults to `"pt"`):
503 | The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
504 | (`torch.Tensor`).
505 | return_dict (`bool`, *optional*, defaults to `True`):
506 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
507 |
508 | Examples:
509 |
510 | Returns:
511 | [`KandinskyPriorPipelineOutput`] or `tuple`
512 | """
513 |
514 | if isinstance(prompt, str):
515 | prompt = [prompt]
516 | elif not isinstance(prompt, list):
517 | raise ValueError(
518 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
519 | )
520 |
521 | if isinstance(negative_prompt, str):
522 | negative_prompt = [negative_prompt]
523 | elif not isinstance(negative_prompt, list) and negative_prompt is not None:
524 | raise ValueError(
525 | f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
526 | )
527 |
528 | # if the negative prompt is defined we double the batch size to
529 | # directly retrieve the negative prompt embedding
530 | if negative_prompt is not None:
531 | prompt = prompt + negative_prompt
532 | negative_prompt = 2 * negative_prompt
533 |
534 | device = self._execution_device
535 |
536 | batch_size = len(prompt)
537 | batch_size = batch_size * num_images_per_prompt
538 |
539 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
540 | prompt, device, num_images_per_prompt, False, negative_prompt
541 | )
542 |
543 | hidden_states = randn_tensor(
544 | (batch_size, prompt_embeds.shape[-1]),
545 | device=prompt_embeds.device,
546 | dtype=prompt_embeds.dtype,
547 | generator=generator,
548 | )
549 |
550 | latents = self.prior(
551 | hidden_states,
552 | proj_embedding=prompt_embeds,
553 | encoder_hidden_states=text_encoder_hidden_states,
554 | attention_mask=text_mask,
555 | ).predicted_image_embedding
556 |
557 | image_embeddings = latents
558 |
559 | # if a negative prompt has been defined, we split the image embedding into two
560 | if negative_prompt is None:
561 | zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
562 |
563 | if (
564 | hasattr(self, "final_offload_hook")
565 | and self.final_offload_hook is not None
566 | ):
567 | self.final_offload_hook.offload()
568 | else:
569 | image_embeddings, zero_embeds = image_embeddings.chunk(2)
570 |
571 | if (
572 | hasattr(self, "final_offload_hook")
573 | and self.final_offload_hook is not None
574 | ):
575 | self.prior_hook.offload()
576 |
577 | if output_type not in ["pt", "np"]:
578 | raise ValueError(
579 | f"Only the output types `pt` and `np` are supported not output_type={output_type}"
580 | )
581 |
582 | if output_type == "np":
583 | image_embeddings = image_embeddings.cpu().numpy()
584 | zero_embeds = zero_embeds.cpu().numpy()
585 |
586 | if not return_dict:
587 | return (image_embeddings, zero_embeds)
588 |
589 | return KandinskyPriorPipelineOutput(
590 | image_embeds=image_embeddings, negative_image_embeds=zero_embeds
591 | )
592 |
--------------------------------------------------------------------------------
/src/pipelines/pipeline_unclip.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append("..")
4 |
5 | import inspect
6 | from typing import List, Optional, Tuple, Union
7 |
8 | import torch
9 | from torch.nn import functional as F
10 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer
11 | from transformers.models.clip.modeling_clip import CLIPTextModelOutput
12 |
13 | from diffusers.models import UNet2DConditionModel, UNet2DModel
14 | from diffusers.schedulers import UnCLIPScheduler
15 | from diffusers.utils import logging, randn_tensor
16 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
17 | from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
18 |
19 |
20 | from diffusers.models import PriorTransformer
21 |
22 |
23 | import torch
24 | from torchvision.transforms import ToPILImage
25 |
26 | import copy
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 |
31 | class UnCLIPPipeline(DiffusionPipeline):
32 | """
33 | Pipeline for text-to-image generation using unCLIP.
34 |
35 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
36 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
37 |
38 | Args:
39 | text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
40 | Frozen text-encoder.
41 | tokenizer ([`~transformers.CLIPTokenizer`]):
42 | A `CLIPTokenizer` to tokenize text.
43 | prior ([`PriorTransformer`]):
44 | The canonical unCLIP prior to approximate the image embedding from the text embedding.
45 | text_proj ([`UnCLIPTextProjModel`]):
46 | Utility class to prepare and combine the embeddings before they are passed to the decoder.
47 | decoder ([`UNet2DConditionModel`]):
48 | The decoder to invert the image embedding into an image.
49 | super_res_first ([`UNet2DModel`]):
50 | Super resolution UNet. Used in all but the last step of the super resolution diffusion process.
51 | super_res_last ([`UNet2DModel`]):
52 | Super resolution UNet. Used in the last step of the super resolution diffusion process.
53 | prior_scheduler ([`UnCLIPScheduler`]):
54 | Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]).
55 | decoder_scheduler ([`UnCLIPScheduler`]):
56 | Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]).
57 | super_res_scheduler ([`UnCLIPScheduler`]):
58 | Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).
59 |
60 | """
61 |
62 | _exclude_from_cpu_offload = ["prior"]
63 |
64 | prior: PriorTransformer
65 | decoder: UNet2DConditionModel
66 | text_proj: UnCLIPTextProjModel
67 | text_encoder: CLIPTextModelWithProjection
68 | tokenizer: CLIPTokenizer
69 | super_res_first: UNet2DModel
70 | super_res_last: UNet2DModel
71 |
72 | prior_scheduler: UnCLIPScheduler
73 | decoder_scheduler: UnCLIPScheduler
74 | super_res_scheduler: UnCLIPScheduler
75 |
76 | def __init__(
77 | self,
78 | prior: PriorTransformer,
79 | decoder: UNet2DConditionModel,
80 | text_encoder: CLIPTextModelWithProjection,
81 | tokenizer: CLIPTokenizer,
82 | text_proj: UnCLIPTextProjModel,
83 | super_res_first: UNet2DModel,
84 | super_res_last: UNet2DModel,
85 | prior_scheduler: UnCLIPScheduler,
86 | decoder_scheduler: UnCLIPScheduler,
87 | super_res_scheduler: UnCLIPScheduler,
88 | ):
89 | super().__init__()
90 |
91 | self.register_modules(
92 | prior=prior,
93 | decoder=decoder,
94 | text_encoder=text_encoder,
95 | tokenizer=tokenizer,
96 | text_proj=text_proj,
97 | super_res_first=super_res_first,
98 | super_res_last=super_res_last,
99 | prior_scheduler=prior_scheduler,
100 | decoder_scheduler=decoder_scheduler,
101 | super_res_scheduler=super_res_scheduler,
102 | )
103 |
104 | def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
105 | if latents is None:
106 | latents = randn_tensor(
107 | shape, generator=generator, device=device, dtype=dtype
108 | )
109 | else:
110 | if latents.shape != shape:
111 | raise ValueError(
112 | f"Unexpected latents shape, got {latents.shape}, expected {shape}"
113 | )
114 | latents = latents.to(device)
115 |
116 | latents = latents * scheduler.init_noise_sigma
117 | return latents
118 |
119 | def _encode_prompt(
120 | self,
121 | prompt,
122 | device,
123 | num_images_per_prompt,
124 | do_classifier_free_guidance,
125 | text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
126 | text_attention_mask: Optional[torch.Tensor] = None,
127 | ):
128 | if text_model_output is None:
129 | batch_size = len(prompt) if isinstance(prompt, list) else 1
130 | # get prompt text embeddings
131 | text_inputs = self.tokenizer(
132 | prompt,
133 | padding="max_length",
134 | max_length=self.tokenizer.model_max_length,
135 | truncation=True,
136 | return_tensors="pt",
137 | )
138 | text_input_ids = text_inputs.input_ids
139 | text_mask = text_inputs.attention_mask.bool().to(device)
140 |
141 | untruncated_ids = self.tokenizer(
142 | prompt, padding="longest", return_tensors="pt"
143 | ).input_ids
144 |
145 | if untruncated_ids.shape[-1] >= text_input_ids.shape[
146 | -1
147 | ] and not torch.equal(text_input_ids, untruncated_ids):
148 | removed_text = self.tokenizer.batch_decode(
149 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
150 | )
151 | logger.warning(
152 | "The following part of your input was truncated because CLIP can only handle sequences up to"
153 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
154 | )
155 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
156 |
157 | text_encoder_output = self.text_encoder(text_input_ids.to(device))
158 |
159 | prompt_embeds = text_encoder_output.text_embeds
160 | text_encoder_hidden_states = text_encoder_output.last_hidden_state
161 |
162 | else:
163 | batch_size = text_model_output[0].shape[0]
164 | prompt_embeds, text_encoder_hidden_states = (
165 | text_model_output[0],
166 | text_model_output[1],
167 | )
168 | text_mask = text_attention_mask
169 |
170 | prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
171 | text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
172 | num_images_per_prompt, dim=0
173 | )
174 | text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
175 |
176 | if do_classifier_free_guidance:
177 | uncond_tokens = [""] * batch_size
178 |
179 | uncond_input = self.tokenizer(
180 | uncond_tokens,
181 | padding="max_length",
182 | max_length=self.tokenizer.model_max_length,
183 | truncation=True,
184 | return_tensors="pt",
185 | )
186 | uncond_text_mask = uncond_input.attention_mask.bool().to(device)
187 | negative_prompt_embeds_text_encoder_output = self.text_encoder(
188 | uncond_input.input_ids.to(device)
189 | )
190 |
191 | negative_prompt_embeds = (
192 | negative_prompt_embeds_text_encoder_output.text_embeds
193 | )
194 | uncond_text_encoder_hidden_states = (
195 | negative_prompt_embeds_text_encoder_output.last_hidden_state
196 | )
197 |
198 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
199 |
200 | seq_len = negative_prompt_embeds.shape[1]
201 | negative_prompt_embeds = negative_prompt_embeds.repeat(
202 | 1, num_images_per_prompt
203 | )
204 | negative_prompt_embeds = negative_prompt_embeds.view(
205 | batch_size * num_images_per_prompt, seq_len
206 | )
207 |
208 | seq_len = uncond_text_encoder_hidden_states.shape[1]
209 | uncond_text_encoder_hidden_states = (
210 | uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
211 | )
212 | uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
213 | batch_size * num_images_per_prompt, seq_len, -1
214 | )
215 | uncond_text_mask = uncond_text_mask.repeat_interleave(
216 | num_images_per_prompt, dim=0
217 | )
218 |
219 | # done duplicates
220 |
221 | # For classifier free guidance, we need to do two forward passes.
222 | # Here we concatenate the unconditional and text embeddings into a single batch
223 | # to avoid doing two forward passes
224 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
225 | text_encoder_hidden_states = torch.cat(
226 | [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
227 | )
228 |
229 | text_mask = torch.cat([uncond_text_mask, text_mask])
230 |
231 | return prompt_embeds, text_encoder_hidden_states, text_mask
232 |
233 | @torch.no_grad()
234 | def __call__(
235 | self,
236 | prompt: Optional[Union[str, List[str]]] = None,
237 | num_images_per_prompt: int = 1,
238 | prior_num_inference_steps: int = 25,
239 | decoder_num_inference_steps: int = 25,
240 | super_res_num_inference_steps: int = 7,
241 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
242 | prior_latents: Optional[torch.FloatTensor] = None,
243 | decoder_latents: Optional[torch.FloatTensor] = None,
244 | super_res_latents: Optional[torch.FloatTensor] = None,
245 | text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
246 | text_attention_mask: Optional[torch.Tensor] = None,
247 | prior_guidance_scale: float = 4.0,
248 | decoder_guidance_scale: float = 8.0,
249 | output_type: Optional[str] = "pil",
250 | return_dict: bool = True,
251 | null_prompt_decoder: bool = False,
252 | ):
253 | """
254 | The call function to the pipeline for generation.
255 |
256 | Args:
257 | prompt (`str` or `List[str]`):
258 | The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output`
259 | and `text_attention_mask` are passed.
260 | num_images_per_prompt (`int`, *optional*, defaults to 1):
261 | The number of images to generate per prompt.
262 | prior_num_inference_steps (`int`, *optional*, defaults to 25):
263 | The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
264 | image at the expense of slower inference.
265 | decoder_num_inference_steps (`int`, *optional*, defaults to 25):
266 | The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
267 | image at the expense of slower inference.
268 | super_res_num_inference_steps (`int`, *optional*, defaults to 7):
269 | The number of denoising steps for super resolution. More denoising steps usually lead to a higher
270 | quality image at the expense of slower inference.
271 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
272 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
273 | generation deterministic.
274 | prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
275 | Pre-generated noisy latents to be used as inputs for the prior.
276 | decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
277 | Pre-generated noisy latents to be used as inputs for the decoder.
278 | super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
279 | Pre-generated noisy latents to be used as inputs for the super resolution.
280 | prior_guidance_scale (`float`, *optional*, defaults to 4.0):
281 | A higher guidance scale value encourages the model to generate images closely linked to the text
282 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
283 | decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
284 | A higher guidance scale value encourages the model to generate images closely linked to the text
285 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
286 | text_model_output (`CLIPTextModelOutput`, *optional*):
287 | Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text
288 | outputs can be passed for tasks like text embedding interpolations. Make sure to also pass
289 | `text_attention_mask` in this case. `prompt` can then be left `None`.
290 | text_attention_mask (`torch.Tensor`, *optional*):
291 | Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
292 | masks are necessary when passing `text_model_output`.
293 | output_type (`str`, *optional*, defaults to `"pil"`):
294 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
295 | return_dict (`bool`, *optional*, defaults to `True`):
296 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
297 |
298 | Returns:
299 | [`~pipelines.ImagePipelineOutput`] or `tuple`:
300 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
301 | returned where the first element is a list with the generated images.
302 | """
303 | if prompt is not None:
304 | if isinstance(prompt, str):
305 | batch_size = 1
306 | elif isinstance(prompt, list):
307 | batch_size = len(prompt)
308 | else:
309 | raise ValueError(
310 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
311 | )
312 | else:
313 | batch_size = text_model_output[0].shape[0]
314 |
315 | device = self._execution_device
316 |
317 | batch_size = batch_size * num_images_per_prompt
318 |
319 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
320 | prompt,
321 | device,
322 | num_images_per_prompt,
323 | False,
324 | text_model_output,
325 | text_attention_mask,
326 | )
327 |
328 | hidden_states = randn_tensor(
329 | (batch_size, prompt_embeds.shape[-1]),
330 | device=prompt_embeds.device,
331 | dtype=prompt_embeds.dtype,
332 | generator=generator,
333 | )
334 |
335 | prior_latents = self.prior(
336 | hidden_states,
337 | proj_embedding=prompt_embeds,
338 | encoder_hidden_states=text_encoder_hidden_states,
339 | attention_mask=text_mask,
340 | ).predicted_image_embedding
341 |
342 | do_classifier_free_guidance = decoder_guidance_scale > 1.0
343 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
344 | prompt if not null_prompt_decoder else "",
345 | device,
346 | num_images_per_prompt,
347 | do_classifier_free_guidance,
348 | text_model_output,
349 | text_attention_mask,
350 | )
351 |
352 | prior_latents = prior_latents.expand(
353 | (
354 | prompt_embeds.shape[0] // 2
355 | if do_classifier_free_guidance
356 | else prompt_embeds.shape[0],
357 | prompt_embeds.shape[1],
358 | )
359 | )
360 | image_embeddings = prior_latents.clone()
361 | # return image_embeddings
362 |
363 | # decoder
364 | text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
365 | image_embeddings=image_embeddings,
366 | prompt_embeds=prompt_embeds,
367 | text_encoder_hidden_states=text_encoder_hidden_states,
368 | do_classifier_free_guidance=do_classifier_free_guidance,
369 | )
370 |
371 | if device.type == "mps":
372 | # HACK: MPS: There is a panic when padding bool tensors,
373 | # so cast to int tensor for the pad and back to bool afterwards
374 | text_mask = text_mask.type(torch.int)
375 | decoder_text_mask = F.pad(
376 | text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1
377 | )
378 | decoder_text_mask = decoder_text_mask.type(torch.bool)
379 | else:
380 | decoder_text_mask = F.pad(
381 | text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True
382 | )
383 |
384 | self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
385 | decoder_timesteps_tensor = self.decoder_scheduler.timesteps
386 |
387 | num_channels_latents = self.decoder.config.in_channels
388 | height = self.decoder.config.sample_size
389 | width = self.decoder.config.sample_size
390 |
391 | decoder_latents = self.prepare_latents(
392 | (batch_size, num_channels_latents, height, width),
393 | text_encoder_hidden_states.dtype,
394 | device,
395 | generator,
396 | decoder_latents,
397 | self.decoder_scheduler,
398 | )
399 |
400 | for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
401 | # expand the latents if we are doing classifier free guidance
402 | latent_model_input = (
403 | torch.cat([decoder_latents] * 2)
404 | if do_classifier_free_guidance
405 | else decoder_latents
406 | )
407 |
408 | noise_pred = self.decoder(
409 | sample=latent_model_input,
410 | timestep=t,
411 | encoder_hidden_states=text_encoder_hidden_states,
412 | class_labels=additive_clip_time_embeddings,
413 | attention_mask=decoder_text_mask,
414 | ).sample
415 |
416 | if do_classifier_free_guidance:
417 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
418 | noise_pred_uncond, _ = noise_pred_uncond.split(
419 | latent_model_input.shape[1], dim=1
420 | )
421 | noise_pred_text, predicted_variance = noise_pred_text.split(
422 | latent_model_input.shape[1], dim=1
423 | )
424 | noise_pred = noise_pred_uncond + decoder_guidance_scale * (
425 | noise_pred_text - noise_pred_uncond
426 | )
427 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
428 |
429 | if i + 1 == decoder_timesteps_tensor.shape[0]:
430 | prev_timestep = None
431 | else:
432 | prev_timestep = decoder_timesteps_tensor[i + 1]
433 |
434 | # compute the previous noisy sample x_t -> x_t-1
435 | decoder_latents = self.decoder_scheduler.step(
436 | noise_pred,
437 | t,
438 | decoder_latents,
439 | prev_timestep=prev_timestep,
440 | generator=generator,
441 | ).prev_sample
442 |
443 | decoder_latents = decoder_latents.clamp(-1, 1)
444 |
445 | image_small = decoder_latents
446 |
447 | # done decoder
448 |
449 | # super res
450 |
451 | self.super_res_scheduler.set_timesteps(
452 | super_res_num_inference_steps, device=device
453 | )
454 | super_res_timesteps_tensor = self.super_res_scheduler.timesteps
455 |
456 | channels = self.super_res_first.config.in_channels // 2
457 | height = self.super_res_first.config.sample_size
458 | width = self.super_res_first.config.sample_size
459 |
460 | super_res_latents = self.prepare_latents(
461 | (batch_size, channels, height, width),
462 | image_small.dtype,
463 | device,
464 | generator,
465 | super_res_latents,
466 | self.super_res_scheduler,
467 | )
468 |
469 | if device.type == "mps":
470 | # MPS does not support many interpolations
471 | image_upscaled = F.interpolate(image_small, size=[height, width])
472 | else:
473 | interpolate_antialias = {}
474 | if "antialias" in inspect.signature(F.interpolate).parameters:
475 | interpolate_antialias["antialias"] = True
476 |
477 | image_upscaled = F.interpolate(
478 | image_small,
479 | size=[height, width],
480 | mode="bicubic",
481 | align_corners=False,
482 | **interpolate_antialias,
483 | )
484 |
485 | for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
486 | # no classifier free guidance
487 |
488 | if i == super_res_timesteps_tensor.shape[0] - 1:
489 | unet = self.super_res_last
490 | else:
491 | unet = self.super_res_first
492 |
493 | latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
494 |
495 | noise_pred = unet(
496 | sample=latent_model_input,
497 | timestep=t,
498 | ).sample
499 |
500 | if i + 1 == super_res_timesteps_tensor.shape[0]:
501 | prev_timestep = None
502 | else:
503 | prev_timestep = super_res_timesteps_tensor[i + 1]
504 |
505 | # compute the previous noisy sample x_t -> x_t-1
506 | super_res_latents = self.super_res_scheduler.step(
507 | noise_pred,
508 | t,
509 | super_res_latents,
510 | prev_timestep=prev_timestep,
511 | generator=generator,
512 | ).prev_sample
513 |
514 | image = super_res_latents
515 | # done super res
516 |
517 | # post processing
518 |
519 | image = image * 0.5 + 0.5
520 | image = image.clamp(0, 1)
521 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
522 |
523 | if output_type == "pil":
524 | image = self.numpy_to_pil(image)
525 |
526 | if not return_dict:
527 | return (image,)
528 |
529 | return ImagePipelineOutput(images=image)
530 |
--------------------------------------------------------------------------------
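
The decoder loop in pipeline_unclip.py above applies classifier-free guidance to a UNet that predicts both noise and a per-channel variance: the batched output is chunked into unconditional and text-conditioned halves, each half is split along the channel axis into the noise estimate and the predicted variance, the guidance formula is applied to the noise estimates only, and the text branch's variance is re-attached before the scheduler step. A minimal sketch of just that split-and-recombine step is shown below; the helper name and tensor shapes are hypothetical and not part of the repository.

import torch

def apply_decoder_cfg(noise_pred, latent_channels, guidance_scale):
    # noise_pred: (2 * batch, 2 * latent_channels, H, W). The batch axis stacks
    # [unconditional, text-conditioned] predictions; the channel axis stacks
    # [noise, predicted_variance], mirroring the decoder loop above.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred_uncond, _ = noise_pred_uncond.split(latent_channels, dim=1)
    noise_pred_text, predicted_variance = noise_pred_text.split(latent_channels, dim=1)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # keep the predicted variance so the UnCLIP-style scheduler can use it in step()
    return torch.cat([guided, predicted_variance], dim=1)

# toy check: batch of 1, 3 latent channels, 64x64 latents
out = apply_decoder_cfg(torch.randn(2, 6, 64, 64), latent_channels=3, guidance_scale=4.0)
assert out.shape == (1, 6, 64, 64)
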
/src/priors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/src/priors/__init__.py
--------------------------------------------------------------------------------
/src/priors/prior_transformer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | from dataclasses import dataclass
5 | from typing import Dict, Optional, Union
6 |
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import nn
11 |
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.utils import BaseOutput
14 | from diffusers.models.attention import BasicTransformerBlock
15 | from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17 | from diffusers.models.modeling_utils import ModelMixin
18 |
19 |
20 | @dataclass
21 | class PriorTransformerOutput(BaseOutput):
22 | """
23 | The output of [`PriorTransformer`].
24 |
25 | Args:
26 | predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
27 | The predicted CLIP image embedding conditioned on the CLIP text embedding input.
28 | """
29 |
30 | predicted_image_embedding: torch.FloatTensor
31 |
32 |
33 | class PriorTransformer(ModelMixin, ConfigMixin):
34 | """
35 | A Prior Transformer model.
36 |
37 | Parameters:
38 | num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
39 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
40 | num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
41 | embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
42 | num_embeddings (`int`, *optional*, defaults to 77):
43 | The number of embeddings of the model input `hidden_states`
44 | additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
45 | projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
46 | additional_embeddings`.
47 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48 | time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
49 | The activation function to use to create timestep embeddings.
50 | norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
51 | passing to Transformer blocks. Set it to `None` if normalization is not needed.
52 | embedding_proj_norm_type (`str`, *optional*, defaults to None):
53 | The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
54 | needed.
55 | encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
56 | The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
57 | `encoder_hidden_states` is `None`.
58 | added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
59 |             Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
60 |             product between the text embedding and image embedding, as proposed in the unCLIP paper
61 |             https://arxiv.org/abs/2204.06125. If it is `None`, no additional embeddings will be prepended.
62 |         time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
63 |             If None, will be set to `num_attention_heads * attention_head_dim`
64 |         embedding_proj_dim (`int`, *optional*, defaults to None):
65 |             The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
66 |         clip_embed_dim (`int`, *optional*, defaults to None):
67 |             The dimension of the output. If None, will be set to `embedding_dim`.
68 | """
69 |
70 | @register_to_config
71 | def __init__(
72 | self,
73 | num_attention_heads: int = 32,
74 | attention_head_dim: int = 64,
75 | num_layers: int = 20,
76 | embedding_dim: int = 768,
77 | num_embeddings=77,
78 |         additional_embeddings=3, # as we have removed the time embedding
79 | dropout: float = 0.0,
80 | # time_embed_act_fn: str = "silu",
81 | norm_in_type: Optional[str] = None, # layer
82 | embedding_proj_norm_type: Optional[str] = None, # layer
83 | encoder_hid_proj_type: Optional[str] = "linear", # linear
84 | added_emb_type: Optional[str] = "prd", # prd
85 | # time_embed_dim: Optional[int] = None,
86 | embedding_proj_dim: Optional[int] = None,
87 | clip_embed_dim: Optional[int] = None,
88 | ):
89 | super().__init__()
90 | self.num_attention_heads = num_attention_heads
91 | self.attention_head_dim = attention_head_dim
92 | inner_dim = num_attention_heads * attention_head_dim
93 | self.additional_embeddings = additional_embeddings
94 |
95 | # time_embed_dim = time_embed_dim or inner_dim
96 | embedding_proj_dim = embedding_proj_dim or embedding_dim
97 | clip_embed_dim = clip_embed_dim or embedding_dim
98 |
99 | # self.time_proj = Timesteps(inner_dim, True, 0)
100 | # self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
101 |
102 | self.proj_in = nn.Linear(embedding_dim, inner_dim)
103 |
104 | if embedding_proj_norm_type is None:
105 | self.embedding_proj_norm = None
106 | elif embedding_proj_norm_type == "layer":
107 | self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
108 | else:
109 | raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
110 |
111 | self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
112 |
113 | if encoder_hid_proj_type is None:
114 | self.encoder_hidden_states_proj = None
115 | elif encoder_hid_proj_type == "linear":
116 | self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
117 | else:
118 | raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
119 |
120 | self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
121 |
122 | if added_emb_type == "prd":
123 | self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
124 | elif added_emb_type is None:
125 | self.prd_embedding = None
126 | else:
127 | raise ValueError(
128 | f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
129 | )
130 |
131 | self.transformer_blocks = nn.ModuleList(
132 | [
133 | BasicTransformerBlock(
134 | inner_dim,
135 | num_attention_heads,
136 | attention_head_dim,
137 | dropout=dropout,
138 | activation_fn="gelu",
139 | attention_bias=True,
140 | )
141 | for d in range(num_layers)
142 | ]
143 | )
144 |
145 | if norm_in_type == "layer":
146 | self.norm_in = nn.LayerNorm(inner_dim)
147 | elif norm_in_type is None:
148 | self.norm_in = None
149 | else:
150 | raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
151 |
152 | self.norm_out = nn.LayerNorm(inner_dim)
153 |
154 | self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
155 |
156 | causal_attention_mask = torch.full(
157 | [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
158 | )
159 | causal_attention_mask.triu_(1)
160 | causal_attention_mask = causal_attention_mask[None, ...]
161 | self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
162 |
163 | self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
164 | self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
165 |
166 | @property
167 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
168 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
169 | r"""
170 | Returns:
171 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
172 |             indexed by their weight names.
173 | """
174 | # set recursively
175 | processors = {}
176 |
177 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
178 | if hasattr(module, "set_processor"):
179 | processors[f"{name}.processor"] = module.processor
180 |
181 | for sub_name, child in module.named_children():
182 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
183 |
184 | return processors
185 |
186 | for name, module in self.named_children():
187 | fn_recursive_add_processors(name, module, processors)
188 |
189 | return processors
190 |
191 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
192 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
193 | r"""
194 | Sets the attention processor to use to compute attention.
195 |
196 | Parameters:
197 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
198 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
199 | for **all** `Attention` layers.
200 |
201 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
202 | processor. This is strongly recommended when setting trainable attention processors.
203 |
204 | """
205 | count = len(self.attn_processors.keys())
206 |
207 | if isinstance(processor, dict) and len(processor) != count:
208 | raise ValueError(
209 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
210 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
211 | )
212 |
213 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
214 | if hasattr(module, "set_processor"):
215 | if not isinstance(processor, dict):
216 | module.set_processor(processor)
217 | else:
218 | module.set_processor(processor.pop(f"{name}.processor"))
219 |
220 | for sub_name, child in module.named_children():
221 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
222 |
223 | for name, module in self.named_children():
224 | fn_recursive_attn_processor(name, module, processor)
225 |
226 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
227 | def set_default_attn_processor(self):
228 | """
229 | Disables custom attention processors and sets the default attention implementation.
230 | """
231 | self.set_attn_processor(AttnProcessor())
232 |
233 | def forward(
234 | self,
235 | hidden_states,
236 | # timestep: Union[torch.Tensor, float, int],
237 | proj_embedding: torch.FloatTensor,
238 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
239 | attention_mask: Optional[torch.BoolTensor] = None,
240 | return_dict: bool = True,
241 | ):
242 | """
243 | The [`PriorTransformer`] forward method.
244 |
245 | Args:
246 | hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247 | The currently predicted image embeddings.
248 | timestep (`torch.LongTensor`):
249 | Current denoising step.
250 | proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
251 | Projected embedding vector the denoising process is conditioned on.
252 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
253 | Hidden states of the text embeddings the denoising process is conditioned on.
254 | attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
255 | Text mask for the text embeddings.
256 | return_dict (`bool`, *optional*, defaults to `True`):
257 | Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
258 | tuple.
259 |
260 | Returns:
261 | [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
262 | If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
263 | tuple is returned where the first element is the sample tensor.
264 | """
265 | batch_size = hidden_states.shape[0]
266 |
267 | # timesteps = timestep
268 | # if not torch.is_tensor(timesteps):
269 | # timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
270 | # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
271 | # timesteps = timesteps[None].to(hidden_states.device)
272 |
273 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
274 | # timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
275 |
276 | # timesteps_projected = self.time_proj(timesteps)
277 |
278 | # timesteps does not contain any weights and will always return f32 tensors
279 | # but time_embedding might be fp16, so we need to cast here.
280 | # timesteps_projected = timesteps_projected.to(dtype=self.dtype)
281 | # time_embeddings = self.time_embedding(timesteps_projected)
282 |
283 | if self.embedding_proj_norm is not None:
284 | proj_embedding = self.embedding_proj_norm(proj_embedding)
285 |
286 | proj_embeddings = self.embedding_proj(proj_embedding)
287 | if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
288 | encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
289 | elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
290 | raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
291 |
292 | hidden_states = self.proj_in(hidden_states)
293 |
294 | positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
295 |
296 | additional_embeds = []
297 | additional_embeddings_len = 0
298 |
299 | if encoder_hidden_states is not None:
300 | additional_embeds.append(encoder_hidden_states)
301 | additional_embeddings_len += encoder_hidden_states.shape[1]
302 |
303 | if len(proj_embeddings.shape) == 2:
304 | proj_embeddings = proj_embeddings[:, None, :]
305 |
306 | if len(hidden_states.shape) == 2:
307 | hidden_states = hidden_states[:, None, :]
308 |
309 | additional_embeds = additional_embeds + [
310 | proj_embeddings,
311 | # time_embeddings[:, None, :],
312 | hidden_states,
313 | ]
314 |
315 | if self.prd_embedding is not None:
316 | prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
317 | additional_embeds.append(prd_embedding)
318 |
319 | hidden_states = torch.cat(
320 | additional_embeds,
321 | dim=1,
322 | )
323 |
324 |         # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
325 | additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
326 | if positional_embeddings.shape[1] < hidden_states.shape[1]:
327 | positional_embeddings = F.pad(
328 | positional_embeddings,
329 | (
330 | 0,
331 | 0,
332 | additional_embeddings_len,
333 | self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
334 | ),
335 | value=0.0,
336 | )
337 |
338 | hidden_states = hidden_states + positional_embeddings
339 |
340 | if attention_mask is not None:
341 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
342 | attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
343 | attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
344 | attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
345 |
346 | if self.norm_in is not None:
347 | hidden_states = self.norm_in(hidden_states)
348 |
349 | for block in self.transformer_blocks:
350 | hidden_states = block(hidden_states, attention_mask=attention_mask)
351 |
352 | hidden_states = self.norm_out(hidden_states)
353 |
354 | if self.prd_embedding is not None:
355 | hidden_states = hidden_states[:, -1]
356 | else:
357 | hidden_states = hidden_states[:, additional_embeddings_len:]
358 |
359 | predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
360 |
361 | if not return_dict:
362 | return (predicted_image_embedding,)
363 |
364 | return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
365 |
366 | def post_process_latents(self, prior_latents):
367 | prior_latents = (prior_latents * self.clip_std) + self.clip_mean
368 | return prior_latents
--------------------------------------------------------------------------------
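
Unlike the stock diffusers PriorTransformer, the variant above removes the timestep projection and embedding, so a single forward pass maps the pooled CLIP text embedding (plus the per-token text hidden states) to a predicted CLIP image embedding. A minimal usage sketch under stated assumptions: the import path follows the repository layout, the 768-dim / 77-token shapes come from the constructor defaults shown above, the inputs are random placeholders, and the clip_mean/clip_std statistics are only meaningful once pretrained weights are loaded.

import torch
from src.priors.prior_transformer import PriorTransformer

# Defaults shown above: embedding_dim=768, num_embeddings=77, additional_embeddings=3.
prior = PriorTransformer()

batch = 2
hidden_states = torch.randn(batch, 768)               # current image-embedding estimate
proj_embedding = torch.randn(batch, 768)              # pooled CLIP text embedding
encoder_hidden_states = torch.randn(batch, 77, 768)   # per-token CLIP text hidden states
text_mask = torch.ones(batch, 77, dtype=torch.bool)

with torch.no_grad():
    out = prior(
        hidden_states=hidden_states,
        proj_embedding=proj_embedding,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=text_mask,
    )

image_embedding = out.predicted_image_embedding        # shape (batch, 768)
# De-normalise with the learned CLIP statistics (zero-initialised here; loaded from a checkpoint in practice).
image_embedding = prior.post_process_latents(image_embedding)
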