├── .gitignore ├── LICENSE ├── README.md ├── assets ├── eclipse_solar_eclipse.png ├── example.png └── results.png ├── main.py ├── requirements.txt └── src ├── pipelines ├── __init__.py ├── pipeline_kandinsky_prior.py └── pipeline_unclip.py └── priors ├── __init__.py └── prior_transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 eclipse-t2i 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ##
[CVPR 2024] ECLIPSE: Revisiting the Text-to-Image Prior for Efficient Image Generation
2 | 3 |
4 |   5 |   6 |   7 | 8 | Solar Eclipse image generated by ECLIPSE 9 | 10 |
11 | 12 | --- 13 | 14 | This repository contains the inference code for our paper, ECLIPSE.
15 | We show how to use the pre-trained ECLIPSE text-to-image prior with diffusion image decoders such as Karlo and Kandinsky.
16 | 17 | - ECLIPSE presents a tiny-prior learning strategy that compresses previous prior models from 1 billion parameters down to 33 million parameters.
18 | - Additionally, the ECLIPSE prior is trained on a mere 5 million image-text (alt-text) pairs.
19 | 20 | > **_News:_** Check out our latest work, [λ-ECLIPSE](https://eclipse-t2i.github.io/Lambda-ECLIPSE/), which extends the T2I priors to efficient zero-shot multi-subject-driven text-to-image generation.
21 | 22 | 23 | **Please follow the steps below to run the inference locally.**
24 | 25 | --- 26 | 27 | **Qualitative Comparisons:**
28 | ![Examples](./assets/example.png)
29 | 30 | 31 | **Quantitative Comparisons:**
32 | ![Results](./assets/results.png)
33 | 34 | ## TODOs:
35 | 36 | - [x] ~~Release ECLIPSE priors for Kandinsky v2.2 and Karlo-v1-alpha.~~
37 | - [x] ~~Release the demo.~~
38 | - [ ] Release ECLIPSE prior with Kandinsky v2.2 LCM decoder. (soon!)
39 | - [ ] Release ECLIPSE prior training code. (will be released in a separate repository)
40 | 41 | 42 | ## Setup
43 | 44 | ### Installation
45 | ```bash
46 | git clone git@github.com:eclipse-t2i/eclipse-inference.git
47 | 
48 | conda create -p ./venv python=3.9
49 | conda activate ./venv && pip install -r requirements.txt
50 | ```
51 | 52 | ### Demo
53 | ```bash
54 | conda activate ./venv
55 | gradio main.py
56 | ```
57 | 58 | ## Run Inference
59 | 60 | This repository supports two pre-trained image decoders: [Karlo-v1-alpha](https://huggingface.co/kakaobrain/karlo-v1-alpha) and [Kandinsky-v2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).
61 | 62 | **Note:** The ECLIPSE prior is not a diffusion model, while the image decoders are.
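In practice, the prior produces the CLIP image embedding in a single forward pass; there is no scheduler and no denoising loop. Only the image decoders below iterate over `num_inference_steps`. The following minimal sketch illustrates this (the tensor shapes are placeholders that assume the Kandinsky v2.2 configuration with 1280-dimensional CLIP-ViT-bigG features; in the real pipelines these inputs come from the CLIP text encoder, as shown in the sections below):

```python
import torch
from src.priors.prior_transformer import PriorTransformer

prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior")

# Placeholder inputs (assumed shapes); the real values come from the CLIP text encoder.
noise = torch.randn(1, 1280)                      # seed vector passed as `hidden_states`
text_embed = torch.randn(1, 1280)                 # projected CLIP text embedding
text_hidden = torch.randn(1, 77, 1280)            # CLIP `last_hidden_state`
text_mask = torch.ones(1, 77, dtype=torch.bool)   # attention mask over the 77 tokens

# One forward pass maps the text features to a predicted CLIP image embedding.
image_embed = prior(
    noise,
    proj_embedding=text_embed,
    encoder_hidden_states=text_hidden,
    attention_mask=text_mask,
).predicted_image_embedding
```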
63 | 64 | 65 | ### Kandinsky Inference
66 | ```python
67 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer
68 | from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline
69 | from src.priors.prior_transformer import PriorTransformer
70 | from diffusers import DiffusionPipeline
71 | import torch
72 | text_encoder = (
73 |     CLIPTextModelWithProjection.from_pretrained(
74 |         "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
75 |         projection_dim=1280,
76 |         torch_dtype=torch.float32,
77 |     )
78 | )
79 | 
80 | tokenizer = CLIPTokenizer.from_pretrained(
81 |     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
82 | )
83 | 
84 | prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior")
85 | pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior",
86 |     prior=prior,
87 |     text_encoder=text_encoder,
88 |     tokenizer=tokenizer,
89 | ).to("cuda")
90 | 
91 | pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder").to("cuda")
92 | 
93 | prompt = "black apples in the basket"
94 | image_embeds, negative_image_embeds = pipe_prior(prompt).to_tuple()
95 | images = pipe(
96 |     num_inference_steps=50,
97 |     image_embeds=image_embeds,
98 |     negative_image_embeds=negative_image_embeds,
99 | ).images
100 | 
101 | images[0]
102 | ```
103 | 104 | 105 | ### Karlo Inference
106 | ```python
107 | from src.pipelines.pipeline_unclip import UnCLIPPipeline
108 | from src.priors.prior_transformer import PriorTransformer
109 | 
110 | prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_Karlo_Prior")
111 | pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", prior=prior).to("cuda")
112 | 
113 | prompt = "black apples in the basket"
114 | images = pipe(prompt, decoder_guidance_scale=7.5).images
115 | 
116 | images[0]
117 | ```
118 | 119 | ## Acknowledgement
120 | 121 | We would like to acknowledge the excellent open-source text-to-image models (Karlo and Kandinsky), without which this work would not have been possible. We also thank Hugging Face for streamlining the T2I models.
122 | -------------------------------------------------------------------------------- /assets/eclipse_solar_eclipse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/eclipse_solar_eclipse.png -------------------------------------------------------------------------------- /assets/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/example.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/assets/results.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from PIL import Image 3 | 4 | import torch 5 | 6 | from torchvision import transforms 7 | from transformers import ( 8 | CLIPProcessor, 9 | CLIPModel, 10 | CLIPTokenizer, 11 | CLIPTextModelWithProjection, 12 | CLIPVisionModelWithProjection, 13 | CLIPFeatureExtractor, 14 | ) 15 | 16 | import math 17 | from typing import List 18 | from PIL import Image, ImageChops 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers import UnCLIPPipeline 23 | 24 | # from diffusers.utils.torch_utils import randn_tensor 25 | 26 | from transformers import CLIPTokenizer 27 | 28 | from src.priors.prior_transformer import ( 29 | PriorTransformer, 30 | ) # original huggingface prior transformer without time conditioning 31 | from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline 32 | 33 | from diffusers import DiffusionPipeline 34 | 35 | 36 | __DEVICE__ = "cpu" 37 | if torch.cuda.is_available(): 38 | __DEVICE__ = "cuda" 39 | 40 | class Ours: 41 | def __init__(self, device): 42 | text_encoder = ( 43 | CLIPTextModelWithProjection.from_pretrained( 44 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", 45 | projection_dim=1280, 46 | torch_dtype=torch.float16, 47 | ) 48 | .eval() 49 | .requires_grad_(False) 50 | ) 51 | 52 | tokenizer = CLIPTokenizer.from_pretrained( 53 | "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" 54 | ) 55 | 56 | prior = PriorTransformer.from_pretrained( 57 | "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior", 58 | torch_dtype=torch.float16, 59 | ) 60 | 61 | self.pipe_prior = KandinskyPriorPipeline.from_pretrained( 62 | "kandinsky-community/kandinsky-2-2-prior", 63 | prior=prior, 64 | text_encoder=text_encoder, 65 | tokenizer=tokenizer, 66 | torch_dtype=torch.float16, 67 | ).to(device) 68 | 69 | self.pipe = DiffusionPipeline.from_pretrained( 70 | "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 71 | ).to(device) 72 | 73 | def inference(self, text, negative_text, steps, guidance_scale): 74 | gen_images = [] 75 | for i in range(1): 76 | image_emb, negative_image_emb = self.pipe_prior( 77 | text, negative_prompt=negative_text 78 | ).to_tuple() 79 | image = self.pipe( 80 | image_embeds=image_emb, 81 | negative_image_embeds=negative_image_emb, 82 | num_inference_steps=steps, 83 | guidance_scale=guidance_scale, 84 | ).images 85 | gen_images.append(image[0]) 86 | return gen_images 87 | 88 | 89 | selected_model = 
Ours(device=__DEVICE__) 90 | 91 | 92 | def get_images(text, negative_text, steps, guidance_scale): 93 | images = selected_model.inference(text, negative_text, steps, guidance_scale) 94 | new_images = [] 95 | for img in images: 96 | new_images.append(img) 97 | return new_images[0] 98 | 99 | 100 | with gr.Blocks() as demo: 101 | gr.Markdown( 102 | """

ECLIPSE: Revisiting the Text-to-Image Prior for Efficient Image Generation

103 |

Project Page | Paper

104 | """ 105 | ) 106 | 107 | with gr.Group(): 108 | with gr.Row(): 109 | with gr.Column(): 110 | text = gr.Textbox( 111 | label="Enter your prompt", 112 | show_label=False, 113 | max_lines=1, 114 | placeholder="Enter your prompt", 115 | elem_id="prompt-text-input", 116 | ).style( 117 | border=(True, False, True, True), 118 | rounded=(True, False, False, True), 119 | container=False, 120 | ) 121 | 122 | with gr.Row(): 123 | with gr.Column(): 124 | negative_text = gr.Textbox( 125 | label="Enter your negative prompt", 126 | show_label=False, 127 | max_lines=1, 128 | placeholder="Enter your negative prompt", 129 | elem_id="prompt-text-input", 130 | ).style( 131 | border=(True, False, True, True), 132 | rounded=(True, False, False, True), 133 | container=False, 134 | ) 135 | 136 | with gr.Row(): 137 | steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=1) 138 | guidance_scale = gr.Slider( 139 | label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1 140 | ) 141 | 142 | with gr.Row(): 143 | btn = gr.Button(value="Generate Image", full_width=False) 144 | 145 | gallery = gr.Image( 146 | height=512, width=512, label="Generated images", show_label=True, elem_id="gallery" 147 | ).style(preview=False, columns=1) 148 | 149 | btn.click( 150 | get_images, 151 | inputs=[ 152 | text, 153 | negative_text, 154 | steps, 155 | guidance_scale, 156 | ], 157 | outputs=gallery, 158 | ) 159 | text.submit( 160 | get_images, 161 | inputs=[ 162 | text, 163 | negative_text, 164 | steps, 165 | guidance_scale, 166 | ], 167 | outputs=gallery, 168 | ) 169 | negative_text.submit( 170 | get_images, 171 | inputs=[ 172 | text, 173 | negative_text, 174 | steps, 175 | guidance_scale, 176 | ], 177 | outputs=gallery, 178 | ) 179 | 180 | with gr.Accordion(label="Ethics & Privacy", open=False): 181 | gr.HTML( 182 | """
183 |

Privacy

184 | We do not collect any images or key data. This demo is designed solely for fun and to reduce the misuse of AI. 185 |

Biases and content acknowledgment

186 | This model inherits the same biases as the pre-trained CLIP model.
187 | """ 188 | ) 189 | 190 | if __name__ == "__main__": 191 | demo.queue(max_size=20).launch() 192 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.24.0 2 | datasets==2.14.6 3 | diffusers==0.20.2 4 | numpy==1.26.1 5 | packaging==23.2 6 | pandas_stubs==1.2.0.57 7 | Pillow==10.1.0 8 | torch==2.0.0 9 | torchvision==0.15.1 10 | tqdm==4.66.1 11 | transformers==4.34.1 12 | gradio 13 | jmespath 14 | opencv-python 15 | PyWavelet 16 | gradio==3.47.1 -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/pipeline_kandinsky_prior.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Union 3 | 4 | import numpy as np 5 | import PIL 6 | import torch 7 | from transformers import ( 8 | CLIPImageProcessor, 9 | CLIPTextModelWithProjection, 10 | CLIPTokenizer, 11 | CLIPVisionModelWithProjection, 12 | ) 13 | 14 | from diffusers.models import PriorTransformer 15 | from diffusers.schedulers import UnCLIPScheduler 16 | from diffusers.utils import ( 17 | BaseOutput, 18 | is_accelerate_available, 19 | is_accelerate_version, 20 | logging, 21 | randn_tensor, 22 | replace_example_docstring, 23 | ) 24 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 25 | 26 | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 28 | 29 | EXAMPLE_DOC_STRING = """ 30 | Examples: 31 | ```py 32 | >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline 33 | >>> import torch 34 | 35 | >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") 36 | >>> pipe_prior.to("cuda") 37 | 38 | >>> prompt = "red cat, 4k photo" 39 | >>> out = pipe_prior(prompt) 40 | >>> image_emb = out.image_embeds 41 | >>> negative_image_emb = out.negative_image_embeds 42 | 43 | >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") 44 | >>> pipe.to("cuda") 45 | 46 | >>> image = pipe( 47 | ... prompt, 48 | ... image_embeds=image_emb, 49 | ... negative_image_embeds=negative_image_emb, 50 | ... height=768, 51 | ... width=768, 52 | ... num_inference_steps=100, 53 | ... ).images 54 | 55 | >>> image[0].save("cat.png") 56 | ``` 57 | """ 58 | 59 | EXAMPLE_INTERPOLATE_DOC_STRING = """ 60 | Examples: 61 | ```py 62 | >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline 63 | >>> from diffusers.utils import load_image 64 | >>> import PIL 65 | 66 | >>> import torch 67 | >>> from torchvision import transforms 68 | 69 | >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( 70 | ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 71 | ... ) 72 | >>> pipe_prior.to("cuda") 73 | 74 | >>> img1 = load_image( 75 | ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" 76 | ... "/kandinsky/cat.png" 77 | ... ) 78 | 79 | >>> img2 = load_image( 80 | ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" 81 | ... "/kandinsky/starry_night.jpeg" 82 | ... 
) 83 | 84 | >>> images_texts = ["a cat", img1, img2] 85 | >>> weights = [0.3, 0.3, 0.4] 86 | >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) 87 | 88 | >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) 89 | >>> pipe.to("cuda") 90 | 91 | >>> image = pipe( 92 | ... "", 93 | ... image_embeds=image_emb, 94 | ... negative_image_embeds=zero_image_emb, 95 | ... height=768, 96 | ... width=768, 97 | ... num_inference_steps=150, 98 | ... ).images[0] 99 | 100 | >>> image.save("starry_cat.png") 101 | ``` 102 | """ 103 | 104 | 105 | @dataclass 106 | class KandinskyPriorPipelineOutput(BaseOutput): 107 | """ 108 | Output class for KandinskyPriorPipeline. 109 | 110 | Args: 111 | image_embeds (`torch.FloatTensor`) 112 | clip image embeddings for text prompt 113 | negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) 114 | clip image embeddings for unconditional tokens 115 | """ 116 | 117 | image_embeds: Union[torch.FloatTensor, np.ndarray] 118 | negative_image_embeds: Union[torch.FloatTensor, np.ndarray] 119 | 120 | 121 | class KandinskyPriorPipeline(DiffusionPipeline): 122 | """ 123 | Pipeline for generating image prior for Kandinsky 124 | 125 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 126 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 127 | 128 | Args: 129 | prior ([`PriorTransformer`]): 130 | The canonincal unCLIP prior to approximate the image embedding from the text embedding. 131 | image_encoder ([`CLIPVisionModelWithProjection`]): 132 | Frozen image-encoder. 133 | text_encoder ([`CLIPTextModelWithProjection`]): 134 | Frozen text-encoder. 135 | tokenizer (`CLIPTokenizer`): 136 | Tokenizer of class 137 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 138 | scheduler ([`UnCLIPScheduler`]): 139 | A scheduler to be used in combination with `prior` to generate image embedding. 140 | """ 141 | 142 | _exclude_from_cpu_offload = ["prior"] 143 | 144 | def __init__( 145 | self, 146 | prior: PriorTransformer, 147 | image_encoder: CLIPVisionModelWithProjection, 148 | text_encoder: CLIPTextModelWithProjection, 149 | tokenizer: CLIPTokenizer, 150 | scheduler: UnCLIPScheduler, 151 | image_processor: CLIPImageProcessor, 152 | ): 153 | super().__init__() 154 | 155 | self.register_modules( 156 | prior=prior, 157 | text_encoder=text_encoder, 158 | tokenizer=tokenizer, 159 | scheduler=scheduler, 160 | image_encoder=image_encoder, 161 | image_processor=image_processor, 162 | ) 163 | 164 | @torch.no_grad() 165 | @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) 166 | def interpolate( 167 | self, 168 | images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], 169 | weights: List[float], 170 | num_images_per_prompt: int = 1, 171 | num_inference_steps: int = 25, 172 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 173 | latents: Optional[torch.FloatTensor] = None, 174 | negative_prior_prompt: Optional[str] = None, 175 | negative_prompt: str = "", 176 | guidance_scale: float = 4.0, 177 | device=None, 178 | ): 179 | """ 180 | Function invoked when using the prior pipeline for interpolation. 181 | 182 | Args: 183 | images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): 184 | list of prompts and images to guide the image generation. 
185 | weights: (`List[float]`): 186 | list of weights for each condition in `images_and_prompts` 187 | num_images_per_prompt (`int`, *optional*, defaults to 1): 188 | The number of images to generate per prompt. 189 | num_inference_steps (`int`, *optional*, defaults to 25): 190 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 191 | expense of slower inference. 192 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 193 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 194 | to make generation deterministic. 195 | latents (`torch.FloatTensor`, *optional*): 196 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 197 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 198 | tensor will ge generated by sampling using the supplied random `generator`. 199 | negative_prior_prompt (`str`, *optional*): 200 | The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if 201 | `guidance_scale` is less than `1`). 202 | negative_prompt (`str` or `List[str]`, *optional*): 203 | The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if 204 | `guidance_scale` is less than `1`). 205 | guidance_scale (`float`, *optional*, defaults to 4.0): 206 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 207 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 208 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 209 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 210 | usually at the expense of lower image quality. 
211 | 212 | Examples: 213 | 214 | Returns: 215 | [`KandinskyPriorPipelineOutput`] or `tuple` 216 | """ 217 | 218 | device = device or self.device 219 | 220 | if len(images_and_prompts) != len(weights): 221 | raise ValueError( 222 | f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" 223 | ) 224 | 225 | image_embeddings = [] 226 | for cond, weight in zip(images_and_prompts, weights): 227 | if isinstance(cond, str): 228 | image_emb = self( 229 | cond, 230 | num_inference_steps=num_inference_steps, 231 | num_images_per_prompt=num_images_per_prompt, 232 | generator=generator, 233 | latents=latents, 234 | negative_prompt=negative_prior_prompt, 235 | guidance_scale=guidance_scale, 236 | ).image_embeds 237 | 238 | elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): 239 | if isinstance(cond, PIL.Image.Image): 240 | cond = ( 241 | self.image_processor(cond, return_tensors="pt") 242 | .pixel_values[0] 243 | .unsqueeze(0) 244 | .to(dtype=self.image_encoder.dtype, device=device) 245 | ) 246 | 247 | image_emb = self.image_encoder(cond)["image_embeds"] 248 | 249 | else: 250 | raise ValueError( 251 | f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" 252 | ) 253 | 254 | image_embeddings.append(image_emb * weight) 255 | 256 | image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True) 257 | 258 | out_zero = self( 259 | negative_prompt, 260 | num_inference_steps=num_inference_steps, 261 | num_images_per_prompt=num_images_per_prompt, 262 | generator=generator, 263 | latents=latents, 264 | negative_prompt=negative_prior_prompt, 265 | guidance_scale=guidance_scale, 266 | ) 267 | zero_image_emb = ( 268 | out_zero.negative_image_embeds 269 | if negative_prompt == "" 270 | else out_zero.image_embeds 271 | ) 272 | 273 | return KandinskyPriorPipelineOutput( 274 | image_embeds=image_emb, negative_image_embeds=zero_image_emb 275 | ) 276 | 277 | # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents 278 | def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): 279 | if latents is None: 280 | latents = randn_tensor( 281 | shape, generator=generator, device=device, dtype=dtype 282 | ) 283 | else: 284 | if latents.shape != shape: 285 | raise ValueError( 286 | f"Unexpected latents shape, got {latents.shape}, expected {shape}" 287 | ) 288 | latents = latents.to(device) 289 | 290 | latents = latents * scheduler.init_noise_sigma 291 | return latents 292 | 293 | def get_zero_embed(self, batch_size=1, device=None): 294 | device = device or self.device 295 | zero_img = torch.zeros( 296 | 1, 297 | 3, 298 | self.image_encoder.config.image_size, 299 | self.image_encoder.config.image_size, 300 | ).to(device=device, dtype=self.image_encoder.dtype) 301 | zero_image_emb = self.image_encoder(zero_img)["image_embeds"] 302 | zero_image_emb = zero_image_emb.repeat(batch_size, 1) 303 | return zero_image_emb 304 | 305 | def _encode_prompt( 306 | self, 307 | prompt, 308 | device, 309 | num_images_per_prompt, 310 | do_classifier_free_guidance, 311 | negative_prompt=None, 312 | ): 313 | batch_size = len(prompt) if isinstance(prompt, list) else 1 314 | # get prompt text embeddings 315 | text_inputs = self.tokenizer( 316 | prompt, 317 | padding="max_length", 318 | max_length=self.tokenizer.model_max_length, 319 | truncation=True, 320 | return_tensors="pt", 321 | ) 322 | text_input_ids = text_inputs.input_ids 
323 | text_mask = text_inputs.attention_mask.bool().to(device) 324 | 325 | untruncated_ids = self.tokenizer( 326 | prompt, padding="longest", return_tensors="pt" 327 | ).input_ids 328 | 329 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 330 | text_input_ids, untruncated_ids 331 | ): 332 | removed_text = self.tokenizer.batch_decode( 333 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 334 | ) 335 | logger.warning( 336 | "The following part of your input was truncated because CLIP can only handle sequences up to" 337 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 338 | ) 339 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 340 | 341 | text_encoder_output = self.text_encoder(text_input_ids.to(device)) 342 | 343 | prompt_embeds = text_encoder_output.text_embeds 344 | text_encoder_hidden_states = text_encoder_output.last_hidden_state 345 | 346 | prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) 347 | text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave( 348 | num_images_per_prompt, dim=0 349 | ) 350 | text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) 351 | 352 | if do_classifier_free_guidance: 353 | uncond_tokens: List[str] 354 | if negative_prompt is None: 355 | uncond_tokens = [""] * batch_size 356 | elif type(prompt) is not type(negative_prompt): 357 | raise TypeError( 358 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 359 | f" {type(prompt)}." 360 | ) 361 | elif isinstance(negative_prompt, str): 362 | uncond_tokens = [negative_prompt] 363 | elif batch_size != len(negative_prompt): 364 | raise ValueError( 365 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 366 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 367 | " the batch size of `prompt`." 
368 | ) 369 | else: 370 | uncond_tokens = negative_prompt 371 | 372 | uncond_input = self.tokenizer( 373 | uncond_tokens, 374 | padding="max_length", 375 | max_length=self.tokenizer.model_max_length, 376 | truncation=True, 377 | return_tensors="pt", 378 | ) 379 | uncond_text_mask = uncond_input.attention_mask.bool().to(device) 380 | negative_prompt_embeds_text_encoder_output = self.text_encoder( 381 | uncond_input.input_ids.to(device) 382 | ) 383 | 384 | negative_prompt_embeds = ( 385 | negative_prompt_embeds_text_encoder_output.text_embeds 386 | ) 387 | uncond_text_encoder_hidden_states = ( 388 | negative_prompt_embeds_text_encoder_output.last_hidden_state 389 | ) 390 | 391 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 392 | 393 | seq_len = negative_prompt_embeds.shape[1] 394 | negative_prompt_embeds = negative_prompt_embeds.repeat( 395 | 1, num_images_per_prompt 396 | ) 397 | negative_prompt_embeds = negative_prompt_embeds.view( 398 | batch_size * num_images_per_prompt, seq_len 399 | ) 400 | 401 | seq_len = uncond_text_encoder_hidden_states.shape[1] 402 | uncond_text_encoder_hidden_states = ( 403 | uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) 404 | ) 405 | uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( 406 | batch_size * num_images_per_prompt, seq_len, -1 407 | ) 408 | uncond_text_mask = uncond_text_mask.repeat_interleave( 409 | num_images_per_prompt, dim=0 410 | ) 411 | 412 | # done duplicates 413 | 414 | # For classifier free guidance, we need to do two forward passes. 415 | # Here we concatenate the unconditional and text embeddings into a single batch 416 | # to avoid doing two forward passes 417 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 418 | text_encoder_hidden_states = torch.cat( 419 | [uncond_text_encoder_hidden_states, text_encoder_hidden_states] 420 | ) 421 | 422 | text_mask = torch.cat([uncond_text_mask, text_mask]) 423 | 424 | return prompt_embeds, text_encoder_hidden_states, text_mask 425 | 426 | def enable_model_cpu_offload(self, gpu_id=0): 427 | r""" 428 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 429 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 430 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 431 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 432 | """ 433 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 434 | from accelerate import cpu_offload_with_hook 435 | else: 436 | raise ImportError( 437 | "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." 438 | ) 439 | 440 | device = torch.device(f"cuda:{gpu_id}") 441 | 442 | if self.device.type != "cpu": 443 | self.to("cpu", silence_dtype_warnings=True) 444 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 445 | 446 | hook = None 447 | for cpu_offloaded_model in [self.text_encoder, self.prior]: 448 | _, hook = cpu_offload_with_hook( 449 | cpu_offloaded_model, device, prev_module_hook=hook 450 | ) 451 | 452 | # We'll offload the last model manually. 
453 | self.prior_hook = hook 454 | 455 | _, hook = cpu_offload_with_hook( 456 | self.image_encoder, device, prev_module_hook=self.prior_hook 457 | ) 458 | 459 | self.final_offload_hook = hook 460 | 461 | @torch.no_grad() 462 | @replace_example_docstring(EXAMPLE_DOC_STRING) 463 | def __call__( 464 | self, 465 | prompt: Union[str, List[str]], 466 | negative_prompt: Optional[Union[str, List[str]]] = None, 467 | num_images_per_prompt: int = 1, 468 | num_inference_steps: int = 25, 469 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 470 | latents: Optional[torch.FloatTensor] = None, 471 | guidance_scale: float = 4.0, 472 | output_type: Optional[str] = "pt", 473 | return_dict: bool = True, 474 | ): 475 | """ 476 | Function invoked when calling the pipeline for generation. 477 | 478 | Args: 479 | prompt (`str` or `List[str]`): 480 | The prompt or prompts to guide the image generation. 481 | negative_prompt (`str` or `List[str]`, *optional*): 482 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 483 | if `guidance_scale` is less than `1`). 484 | num_images_per_prompt (`int`, *optional*, defaults to 1): 485 | The number of images to generate per prompt. 486 | num_inference_steps (`int`, *optional*, defaults to 25): 487 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 488 | expense of slower inference. 489 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 490 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 491 | to make generation deterministic. 492 | latents (`torch.FloatTensor`, *optional*): 493 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 494 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 495 | tensor will ge generated by sampling using the supplied random `generator`. 496 | guidance_scale (`float`, *optional*, defaults to 4.0): 497 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 498 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 499 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 500 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 501 | usually at the expense of lower image quality. 502 | output_type (`str`, *optional*, defaults to `"pt"`): 503 | The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` 504 | (`torch.Tensor`). 505 | return_dict (`bool`, *optional*, defaults to `True`): 506 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
507 | 508 | Examples: 509 | 510 | Returns: 511 | [`KandinskyPriorPipelineOutput`] or `tuple` 512 | """ 513 | 514 | if isinstance(prompt, str): 515 | prompt = [prompt] 516 | elif not isinstance(prompt, list): 517 | raise ValueError( 518 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 519 | ) 520 | 521 | if isinstance(negative_prompt, str): 522 | negative_prompt = [negative_prompt] 523 | elif not isinstance(negative_prompt, list) and negative_prompt is not None: 524 | raise ValueError( 525 | f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}" 526 | ) 527 | 528 | # if the negative prompt is defined we double the batch size to 529 | # directly retrieve the negative prompt embedding 530 | if negative_prompt is not None: 531 | prompt = prompt + negative_prompt 532 | negative_prompt = 2 * negative_prompt 533 | 534 | device = self._execution_device 535 | 536 | batch_size = len(prompt) 537 | batch_size = batch_size * num_images_per_prompt 538 | 539 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( 540 | prompt, device, num_images_per_prompt, False, negative_prompt 541 | ) 542 | 543 | hidden_states = randn_tensor( 544 | (batch_size, prompt_embeds.shape[-1]), 545 | device=prompt_embeds.device, 546 | dtype=prompt_embeds.dtype, 547 | generator=generator, 548 | ) 549 | 550 | latents = self.prior( 551 | hidden_states, 552 | proj_embedding=prompt_embeds, 553 | encoder_hidden_states=text_encoder_hidden_states, 554 | attention_mask=text_mask, 555 | ).predicted_image_embedding 556 | 557 | image_embeddings = latents 558 | 559 | # if negative prompt has been defined, we retrieve split the image embedding into two 560 | if negative_prompt is None: 561 | zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) 562 | 563 | if ( 564 | hasattr(self, "final_offload_hook") 565 | and self.final_offload_hook is not None 566 | ): 567 | self.final_offload_hook.offload() 568 | else: 569 | image_embeddings, zero_embeds = image_embeddings.chunk(2) 570 | 571 | if ( 572 | hasattr(self, "final_offload_hook") 573 | and self.final_offload_hook is not None 574 | ): 575 | self.prior_hook.offload() 576 | 577 | if output_type not in ["pt", "np"]: 578 | raise ValueError( 579 | f"Only the output types `pt` and `np` are supported not output_type={output_type}" 580 | ) 581 | 582 | if output_type == "np": 583 | image_embeddings = image_embeddings.cpu().numpy() 584 | zero_embeds = zero_embeds.cpu().numpy() 585 | 586 | if not return_dict: 587 | return (image_embeddings, zero_embeds) 588 | 589 | return KandinskyPriorPipelineOutput( 590 | image_embeds=image_embeddings, negative_image_embeds=zero_embeds 591 | ) 592 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_unclip.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | 5 | import inspect 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | from torch.nn import functional as F 10 | from transformers import CLIPTextModelWithProjection, CLIPTokenizer 11 | from transformers.models.clip.modeling_clip import CLIPTextModelOutput 12 | 13 | from diffusers.models import UNet2DConditionModel, UNet2DModel 14 | from diffusers.schedulers import UnCLIPScheduler 15 | from diffusers.utils import logging, randn_tensor 16 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 17 | from 
diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel 18 | 19 | 20 | from diffusers.models import PriorTransformer 21 | 22 | 23 | import torch 24 | from torchvision.transforms import ToPILImage 25 | 26 | import copy 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | class UnCLIPPipeline(DiffusionPipeline): 32 | """ 33 | Pipeline for text-to-image generation using unCLIP. 34 | 35 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 36 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 37 | 38 | Args: 39 | text_encoder ([`~transformers.CLIPTextModelWithProjection`]): 40 | Frozen text-encoder. 41 | tokenizer ([`~transformers.CLIPTokenizer`]): 42 | A `CLIPTokenizer` to tokenize text. 43 | prior ([`PriorTransformer`]): 44 | The canonical unCLIP prior to approximate the image embedding from the text embedding. 45 | text_proj ([`UnCLIPTextProjModel`]): 46 | Utility class to prepare and combine the embeddings before they are passed to the decoder. 47 | decoder ([`UNet2DConditionModel`]): 48 | The decoder to invert the image embedding into an image. 49 | super_res_first ([`UNet2DModel`]): 50 | Super resolution UNet. Used in all but the last step of the super resolution diffusion process. 51 | super_res_last ([`UNet2DModel`]): 52 | Super resolution UNet. Used in the last step of the super resolution diffusion process. 53 | prior_scheduler ([`UnCLIPScheduler`]): 54 | Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]). 55 | decoder_scheduler ([`UnCLIPScheduler`]): 56 | Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). 57 | super_res_scheduler ([`UnCLIPScheduler`]): 58 | Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). 
59 | 60 | """ 61 | 62 | _exclude_from_cpu_offload = ["prior"] 63 | 64 | prior: PriorTransformer 65 | decoder: UNet2DConditionModel 66 | text_proj: UnCLIPTextProjModel 67 | text_encoder: CLIPTextModelWithProjection 68 | tokenizer: CLIPTokenizer 69 | super_res_first: UNet2DModel 70 | super_res_last: UNet2DModel 71 | 72 | prior_scheduler: UnCLIPScheduler 73 | decoder_scheduler: UnCLIPScheduler 74 | super_res_scheduler: UnCLIPScheduler 75 | 76 | def __init__( 77 | self, 78 | prior: PriorTransformer, 79 | decoder: UNet2DConditionModel, 80 | text_encoder: CLIPTextModelWithProjection, 81 | tokenizer: CLIPTokenizer, 82 | text_proj: UnCLIPTextProjModel, 83 | super_res_first: UNet2DModel, 84 | super_res_last: UNet2DModel, 85 | prior_scheduler: UnCLIPScheduler, 86 | decoder_scheduler: UnCLIPScheduler, 87 | super_res_scheduler: UnCLIPScheduler, 88 | ): 89 | super().__init__() 90 | 91 | self.register_modules( 92 | prior=prior, 93 | decoder=decoder, 94 | text_encoder=text_encoder, 95 | tokenizer=tokenizer, 96 | text_proj=text_proj, 97 | super_res_first=super_res_first, 98 | super_res_last=super_res_last, 99 | prior_scheduler=prior_scheduler, 100 | decoder_scheduler=decoder_scheduler, 101 | super_res_scheduler=super_res_scheduler, 102 | ) 103 | 104 | def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): 105 | if latents is None: 106 | latents = randn_tensor( 107 | shape, generator=generator, device=device, dtype=dtype 108 | ) 109 | else: 110 | if latents.shape != shape: 111 | raise ValueError( 112 | f"Unexpected latents shape, got {latents.shape}, expected {shape}" 113 | ) 114 | latents = latents.to(device) 115 | 116 | latents = latents * scheduler.init_noise_sigma 117 | return latents 118 | 119 | def _encode_prompt( 120 | self, 121 | prompt, 122 | device, 123 | num_images_per_prompt, 124 | do_classifier_free_guidance, 125 | text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, 126 | text_attention_mask: Optional[torch.Tensor] = None, 127 | ): 128 | if text_model_output is None: 129 | batch_size = len(prompt) if isinstance(prompt, list) else 1 130 | # get prompt text embeddings 131 | text_inputs = self.tokenizer( 132 | prompt, 133 | padding="max_length", 134 | max_length=self.tokenizer.model_max_length, 135 | truncation=True, 136 | return_tensors="pt", 137 | ) 138 | text_input_ids = text_inputs.input_ids 139 | text_mask = text_inputs.attention_mask.bool().to(device) 140 | 141 | untruncated_ids = self.tokenizer( 142 | prompt, padding="longest", return_tensors="pt" 143 | ).input_ids 144 | 145 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 146 | -1 147 | ] and not torch.equal(text_input_ids, untruncated_ids): 148 | removed_text = self.tokenizer.batch_decode( 149 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 150 | ) 151 | logger.warning( 152 | "The following part of your input was truncated because CLIP can only handle sequences up to" 153 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 154 | ) 155 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 156 | 157 | text_encoder_output = self.text_encoder(text_input_ids.to(device)) 158 | 159 | prompt_embeds = text_encoder_output.text_embeds 160 | text_encoder_hidden_states = text_encoder_output.last_hidden_state 161 | 162 | else: 163 | batch_size = text_model_output[0].shape[0] 164 | prompt_embeds, text_encoder_hidden_states = ( 165 | text_model_output[0], 166 | text_model_output[1], 167 | ) 168 | text_mask = text_attention_mask 169 | 170 | prompt_embeds = 
prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) 171 | text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave( 172 | num_images_per_prompt, dim=0 173 | ) 174 | text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) 175 | 176 | if do_classifier_free_guidance: 177 | uncond_tokens = [""] * batch_size 178 | 179 | uncond_input = self.tokenizer( 180 | uncond_tokens, 181 | padding="max_length", 182 | max_length=self.tokenizer.model_max_length, 183 | truncation=True, 184 | return_tensors="pt", 185 | ) 186 | uncond_text_mask = uncond_input.attention_mask.bool().to(device) 187 | negative_prompt_embeds_text_encoder_output = self.text_encoder( 188 | uncond_input.input_ids.to(device) 189 | ) 190 | 191 | negative_prompt_embeds = ( 192 | negative_prompt_embeds_text_encoder_output.text_embeds 193 | ) 194 | uncond_text_encoder_hidden_states = ( 195 | negative_prompt_embeds_text_encoder_output.last_hidden_state 196 | ) 197 | 198 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 199 | 200 | seq_len = negative_prompt_embeds.shape[1] 201 | negative_prompt_embeds = negative_prompt_embeds.repeat( 202 | 1, num_images_per_prompt 203 | ) 204 | negative_prompt_embeds = negative_prompt_embeds.view( 205 | batch_size * num_images_per_prompt, seq_len 206 | ) 207 | 208 | seq_len = uncond_text_encoder_hidden_states.shape[1] 209 | uncond_text_encoder_hidden_states = ( 210 | uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) 211 | ) 212 | uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( 213 | batch_size * num_images_per_prompt, seq_len, -1 214 | ) 215 | uncond_text_mask = uncond_text_mask.repeat_interleave( 216 | num_images_per_prompt, dim=0 217 | ) 218 | 219 | # done duplicates 220 | 221 | # For classifier free guidance, we need to do two forward passes. 222 | # Here we concatenate the unconditional and text embeddings into a single batch 223 | # to avoid doing two forward passes 224 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 225 | text_encoder_hidden_states = torch.cat( 226 | [uncond_text_encoder_hidden_states, text_encoder_hidden_states] 227 | ) 228 | 229 | text_mask = torch.cat([uncond_text_mask, text_mask]) 230 | 231 | return prompt_embeds, text_encoder_hidden_states, text_mask 232 | 233 | @torch.no_grad() 234 | def __call__( 235 | self, 236 | prompt: Optional[Union[str, List[str]]] = None, 237 | num_images_per_prompt: int = 1, 238 | prior_num_inference_steps: int = 25, 239 | decoder_num_inference_steps: int = 25, 240 | super_res_num_inference_steps: int = 7, 241 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 242 | prior_latents: Optional[torch.FloatTensor] = None, 243 | decoder_latents: Optional[torch.FloatTensor] = None, 244 | super_res_latents: Optional[torch.FloatTensor] = None, 245 | text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, 246 | text_attention_mask: Optional[torch.Tensor] = None, 247 | prior_guidance_scale: float = 4.0, 248 | decoder_guidance_scale: float = 8.0, 249 | output_type: Optional[str] = "pil", 250 | return_dict: bool = True, 251 | null_prompt_decoder: bool = False, 252 | ): 253 | """ 254 | The call function to the pipeline for generation. 255 | 256 | Args: 257 | prompt (`str` or `List[str]`): 258 | The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output` 259 | and `text_attention_mask` is passed. 
260 | num_images_per_prompt (`int`, *optional*, defaults to 1): 261 | The number of images to generate per prompt. 262 | prior_num_inference_steps (`int`, *optional*, defaults to 25): 263 | The number of denoising steps for the prior. More denoising steps usually lead to a higher quality 264 | image at the expense of slower inference. 265 | decoder_num_inference_steps (`int`, *optional*, defaults to 25): 266 | The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality 267 | image at the expense of slower inference. 268 | super_res_num_inference_steps (`int`, *optional*, defaults to 7): 269 | The number of denoising steps for super resolution. More denoising steps usually lead to a higher 270 | quality image at the expense of slower inference. 271 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 272 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 273 | generation deterministic. 274 | prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): 275 | Pre-generated noisy latents to be used as inputs for the prior. 276 | decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): 277 | Pre-generated noisy latents to be used as inputs for the decoder. 278 | super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): 279 | Pre-generated noisy latents to be used as inputs for the decoder. 280 | prior_guidance_scale (`float`, *optional*, defaults to 4.0): 281 | A higher guidance scale value encourages the model to generate images closely linked to the text 282 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 283 | decoder_guidance_scale (`float`, *optional*, defaults to 4.0): 284 | A higher guidance scale value encourages the model to generate images closely linked to the text 285 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 286 | text_model_output (`CLIPTextModelOutput`, *optional*): 287 | Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text 288 | outputs can be passed for tasks like text embedding interpolations. Make sure to also pass 289 | `text_attention_mask` in this case. `prompt` can the be left `None`. 290 | text_attention_mask (`torch.Tensor`, *optional*): 291 | Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention 292 | masks are necessary when passing `text_model_output`. 293 | output_type (`str`, *optional*, defaults to `"pil"`): 294 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 295 | return_dict (`bool`, *optional*, defaults to `True`): 296 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 297 | 298 | Returns: 299 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 300 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 301 | returned where the first element is a list with the generated images. 
302 | """ 303 | if prompt is not None: 304 | if isinstance(prompt, str): 305 | batch_size = 1 306 | elif isinstance(prompt, list): 307 | batch_size = len(prompt) 308 | else: 309 | raise ValueError( 310 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 311 | ) 312 | else: 313 | batch_size = text_model_output[0].shape[0] 314 | 315 | device = self._execution_device 316 | 317 | batch_size = batch_size * num_images_per_prompt 318 | 319 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( 320 | prompt, 321 | device, 322 | num_images_per_prompt, 323 | False, 324 | text_model_output, 325 | text_attention_mask, 326 | ) 327 | 328 | hidden_states = randn_tensor( 329 | (batch_size, prompt_embeds.shape[-1]), 330 | device=prompt_embeds.device, 331 | dtype=prompt_embeds.dtype, 332 | generator=generator, 333 | ) 334 | 335 | prior_latents = self.prior( 336 | hidden_states, 337 | proj_embedding=prompt_embeds, 338 | encoder_hidden_states=text_encoder_hidden_states, 339 | attention_mask=text_mask, 340 | ).predicted_image_embedding 341 | 342 | do_classifier_free_guidance = decoder_guidance_scale > 1.0 343 | prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( 344 | prompt if not null_prompt_decoder else "", 345 | device, 346 | num_images_per_prompt, 347 | do_classifier_free_guidance, 348 | text_model_output, 349 | text_attention_mask, 350 | ) 351 | 352 | prior_latents = prior_latents.expand( 353 | ( 354 | prompt_embeds.shape[0] // 2 355 | if do_classifier_free_guidance 356 | else prompt_embeds.shape[0], 357 | prompt_embeds.shape[1], 358 | ) 359 | ) 360 | image_embeddings = prior_latents.clone() 361 | # return image_embeddings 362 | 363 | # decoder 364 | text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( 365 | image_embeddings=image_embeddings, 366 | prompt_embeds=prompt_embeds, 367 | text_encoder_hidden_states=text_encoder_hidden_states, 368 | do_classifier_free_guidance=do_classifier_free_guidance, 369 | ) 370 | 371 | if device.type == "mps": 372 | # HACK: MPS: There is a panic when padding bool tensors, 373 | # so cast to int tensor for the pad and back to bool afterwards 374 | text_mask = text_mask.type(torch.int) 375 | decoder_text_mask = F.pad( 376 | text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1 377 | ) 378 | decoder_text_mask = decoder_text_mask.type(torch.bool) 379 | else: 380 | decoder_text_mask = F.pad( 381 | text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True 382 | ) 383 | 384 | self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) 385 | decoder_timesteps_tensor = self.decoder_scheduler.timesteps 386 | 387 | num_channels_latents = self.decoder.config.in_channels 388 | height = self.decoder.config.sample_size 389 | width = self.decoder.config.sample_size 390 | 391 | decoder_latents = self.prepare_latents( 392 | (batch_size, num_channels_latents, height, width), 393 | text_encoder_hidden_states.dtype, 394 | device, 395 | generator, 396 | decoder_latents, 397 | self.decoder_scheduler, 398 | ) 399 | 400 | for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): 401 | # expand the latents if we are doing classifier free guidance 402 | latent_model_input = ( 403 | torch.cat([decoder_latents] * 2) 404 | if do_classifier_free_guidance 405 | else decoder_latents 406 | ) 407 | 408 | noise_pred = self.decoder( 409 | sample=latent_model_input, 410 | timestep=t, 411 | encoder_hidden_states=text_encoder_hidden_states, 412 | 
class_labels=additive_clip_time_embeddings, 413 | attention_mask=decoder_text_mask, 414 | ).sample 415 | 416 | if do_classifier_free_guidance: 417 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 418 | noise_pred_uncond, _ = noise_pred_uncond.split( 419 | latent_model_input.shape[1], dim=1 420 | ) 421 | noise_pred_text, predicted_variance = noise_pred_text.split( 422 | latent_model_input.shape[1], dim=1 423 | ) 424 | noise_pred = noise_pred_uncond + decoder_guidance_scale * ( 425 | noise_pred_text - noise_pred_uncond 426 | ) 427 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) 428 | 429 | if i + 1 == decoder_timesteps_tensor.shape[0]: 430 | prev_timestep = None 431 | else: 432 | prev_timestep = decoder_timesteps_tensor[i + 1] 433 | 434 | # compute the previous noisy sample x_t -> x_t-1 435 | decoder_latents = self.decoder_scheduler.step( 436 | noise_pred, 437 | t, 438 | decoder_latents, 439 | prev_timestep=prev_timestep, 440 | generator=generator, 441 | ).prev_sample 442 | 443 | decoder_latents = decoder_latents.clamp(-1, 1) 444 | 445 | image_small = decoder_latents 446 | 447 | # done decoder 448 | 449 | # super res 450 | 451 | self.super_res_scheduler.set_timesteps( 452 | super_res_num_inference_steps, device=device 453 | ) 454 | super_res_timesteps_tensor = self.super_res_scheduler.timesteps 455 | 456 | channels = self.super_res_first.config.in_channels // 2 457 | height = self.super_res_first.config.sample_size 458 | width = self.super_res_first.config.sample_size 459 | 460 | super_res_latents = self.prepare_latents( 461 | (batch_size, channels, height, width), 462 | image_small.dtype, 463 | device, 464 | generator, 465 | super_res_latents, 466 | self.super_res_scheduler, 467 | ) 468 | 469 | if device.type == "mps": 470 | # MPS does not support many interpolations 471 | image_upscaled = F.interpolate(image_small, size=[height, width]) 472 | else: 473 | interpolate_antialias = {} 474 | if "antialias" in inspect.signature(F.interpolate).parameters: 475 | interpolate_antialias["antialias"] = True 476 | 477 | image_upscaled = F.interpolate( 478 | image_small, 479 | size=[height, width], 480 | mode="bicubic", 481 | align_corners=False, 482 | **interpolate_antialias, 483 | ) 484 | 485 | for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): 486 | # no classifier free guidance 487 | 488 | if i == super_res_timesteps_tensor.shape[0] - 1: 489 | unet = self.super_res_last 490 | else: 491 | unet = self.super_res_first 492 | 493 | latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) 494 | 495 | noise_pred = unet( 496 | sample=latent_model_input, 497 | timestep=t, 498 | ).sample 499 | 500 | if i + 1 == super_res_timesteps_tensor.shape[0]: 501 | prev_timestep = None 502 | else: 503 | prev_timestep = super_res_timesteps_tensor[i + 1] 504 | 505 | # compute the previous noisy sample x_t -> x_t-1 506 | super_res_latents = self.super_res_scheduler.step( 507 | noise_pred, 508 | t, 509 | super_res_latents, 510 | prev_timestep=prev_timestep, 511 | generator=generator, 512 | ).prev_sample 513 | 514 | image = super_res_latents 515 | # done super res 516 | 517 | # post processing 518 | 519 | image = image * 0.5 + 0.5 520 | image = image.clamp(0, 1) 521 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 522 | 523 | if output_type == "pil": 524 | image = self.numpy_to_pil(image) 525 | 526 | if not return_dict: 527 | return (image,) 528 | 529 | return ImagePipelineOutput(images=image) 530 | 
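# Usage sketch: the checkpoints below are the ones referenced in the repository README, and a
# CUDA device is assumed. Unlike the stock diffusers UnCLIPPipeline, this variant consumes the
# ECLIPSE prior in a single forward pass (there is no prior denoising loop, so
# `prior_num_inference_steps` and `prior_guidance_scale` are effectively unused), and the extra
# `null_prompt_decoder` flag conditions the decoder on an empty prompt while keeping the
# prior's predicted image embedding.
#
#   from src.pipelines.pipeline_unclip import UnCLIPPipeline
#   from src.priors.prior_transformer import PriorTransformer
#
#   prior = PriorTransformer.from_pretrained("ECLIPSE-Community/ECLIPSE_Karlo_Prior")
#   pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", prior=prior).to("cuda")
#   image = pipe(
#       "black apples in the basket",
#       decoder_guidance_scale=7.5,
#       null_prompt_decoder=False,  # set to True to decode with an empty text prompt
#   ).images[0]
#   image.save("black_apples.png")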
--------------------------------------------------------------------------------
/src/priors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eclipse-t2i/eclipse-inference/57814633309b8057220bd0f6bdb1ffdc98a2979b/src/priors/__init__.py
--------------------------------------------------------------------------------
/src/priors/prior_transformer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | 
4 | from dataclasses import dataclass
5 | from typing import Dict, Optional, Union
6 | 
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import nn
11 | 
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.utils import BaseOutput
14 | from diffusers.models.attention import BasicTransformerBlock
15 | from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17 | from diffusers.models.modeling_utils import ModelMixin
18 | 
19 | 
20 | @dataclass
21 | class PriorTransformerOutput(BaseOutput):
22 |     """
23 |     The output of [`PriorTransformer`].
24 | 
25 |     Args:
26 |         predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
27 |             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
28 |     """
29 | 
30 |     predicted_image_embedding: torch.FloatTensor
31 | 
32 | 
33 | class PriorTransformer(ModelMixin, ConfigMixin):
34 |     """
35 |     A Prior Transformer model.
36 | 
37 |     Parameters:
38 |         num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
39 |         attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
40 |         num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
41 |         embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
42 |         num_embeddings (`int`, *optional*, defaults to 77):
43 |             The number of embeddings of the model input `hidden_states`
44 |         additional_embeddings (`int`, *optional*, defaults to 3): The number of additional tokens appended to the
45 |             projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
46 |             additional_embeddings`.
47 |         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48 |         time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
49 |             The activation function to use to create timestep embeddings.
50 |         norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
51 |             passing to Transformer blocks. Set it to `None` if normalization is not needed.
52 |         embedding_proj_norm_type (`str`, *optional*, defaults to None):
53 |             The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
54 |             needed.
55 |         encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
56 |             The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
57 |             `encoder_hidden_states` is `None`.
58 |         added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
59 |             Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
60 |             product between the text embedding and image embedding as proposed in the unCLIP paper
61 |             https://arxiv.org/abs/2204.06125. If it is `None`, no additional embeddings will be prepended.
62 |         time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
63 |             If None, will be set to `num_attention_heads * attention_head_dim`
64 |         embedding_proj_dim (`int`, *optional*, defaults to None):
65 |             The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
66 |         clip_embed_dim (`int`, *optional*, defaults to None):
67 |             The dimension of the output. If None, will be set to `embedding_dim`.
68 |     """
69 | 
70 |     @register_to_config
71 |     def __init__(
72 |         self,
73 |         num_attention_heads: int = 32,
74 |         attention_head_dim: int = 64,
75 |         num_layers: int = 20,
76 |         embedding_dim: int = 768,
77 |         num_embeddings=77,
78 |         additional_embeddings=3, # as we have removed the time embedding
79 |         dropout: float = 0.0,
80 |         # time_embed_act_fn: str = "silu",
81 |         norm_in_type: Optional[str] = None, # layer
82 |         embedding_proj_norm_type: Optional[str] = None, # layer
83 |         encoder_hid_proj_type: Optional[str] = "linear", # linear
84 |         added_emb_type: Optional[str] = "prd", # prd
85 |         # time_embed_dim: Optional[int] = None,
86 |         embedding_proj_dim: Optional[int] = None,
87 |         clip_embed_dim: Optional[int] = None,
88 |     ):
89 |         super().__init__()
90 |         self.num_attention_heads = num_attention_heads
91 |         self.attention_head_dim = attention_head_dim
92 |         inner_dim = num_attention_heads * attention_head_dim
93 |         self.additional_embeddings = additional_embeddings
94 | 
95 |         # time_embed_dim = time_embed_dim or inner_dim
96 |         embedding_proj_dim = embedding_proj_dim or embedding_dim
97 |         clip_embed_dim = clip_embed_dim or embedding_dim
98 | 
99 |         # self.time_proj = Timesteps(inner_dim, True, 0)
100 |         # self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
101 | 
102 |         self.proj_in = nn.Linear(embedding_dim, inner_dim)
103 | 
104 |         if embedding_proj_norm_type is None:
105 |             self.embedding_proj_norm = None
106 |         elif embedding_proj_norm_type == "layer":
107 |             self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
108 |         else:
109 |             raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
110 | 
111 |         self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
112 | 
113 |         if encoder_hid_proj_type is None:
114 |             self.encoder_hidden_states_proj = None
115 |         elif encoder_hid_proj_type == "linear":
116 |             self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
117 |         else:
118 |             raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
119 | 
120 |         self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
121 | 
122 |         if added_emb_type == "prd":
123 |             self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
124 |         elif added_emb_type is None:
125 |             self.prd_embedding = None
126 |         else:
127 |             raise ValueError(
128 |                 f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
129 | ) 130 | 131 | self.transformer_blocks = nn.ModuleList( 132 | [ 133 | BasicTransformerBlock( 134 | inner_dim, 135 | num_attention_heads, 136 | attention_head_dim, 137 | dropout=dropout, 138 | activation_fn="gelu", 139 | attention_bias=True, 140 | ) 141 | for d in range(num_layers) 142 | ] 143 | ) 144 | 145 | if norm_in_type == "layer": 146 | self.norm_in = nn.LayerNorm(inner_dim) 147 | elif norm_in_type is None: 148 | self.norm_in = None 149 | else: 150 | raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") 151 | 152 | self.norm_out = nn.LayerNorm(inner_dim) 153 | 154 | self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) 155 | 156 | causal_attention_mask = torch.full( 157 | [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 158 | ) 159 | causal_attention_mask.triu_(1) 160 | causal_attention_mask = causal_attention_mask[None, ...] 161 | self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) 162 | 163 | self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) 164 | self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) 165 | 166 | @property 167 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 168 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 169 | r""" 170 | Returns: 171 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 172 | indexed by its weight name. 173 | """ 174 | # set recursively 175 | processors = {} 176 | 177 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 178 | if hasattr(module, "set_processor"): 179 | processors[f"{name}.processor"] = module.processor 180 | 181 | for sub_name, child in module.named_children(): 182 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 183 | 184 | return processors 185 | 186 | for name, module in self.named_children(): 187 | fn_recursive_add_processors(name, module, processors) 188 | 189 | return processors 190 | 191 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 192 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 193 | r""" 194 | Sets the attention processor to use to compute attention. 195 | 196 | Parameters: 197 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 198 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 199 | for **all** `Attention` layers. 200 | 201 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 202 | processor. This is strongly recommended when setting trainable attention processors. 203 | 204 | """ 205 | count = len(self.attn_processors.keys()) 206 | 207 | if isinstance(processor, dict) and len(processor) != count: 208 | raise ValueError( 209 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 210 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
211 | ) 212 | 213 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 214 | if hasattr(module, "set_processor"): 215 | if not isinstance(processor, dict): 216 | module.set_processor(processor) 217 | else: 218 | module.set_processor(processor.pop(f"{name}.processor")) 219 | 220 | for sub_name, child in module.named_children(): 221 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 222 | 223 | for name, module in self.named_children(): 224 | fn_recursive_attn_processor(name, module, processor) 225 | 226 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 227 | def set_default_attn_processor(self): 228 | """ 229 | Disables custom attention processors and sets the default attention implementation. 230 | """ 231 | self.set_attn_processor(AttnProcessor()) 232 | 233 | def forward( 234 | self, 235 | hidden_states, 236 | # timestep: Union[torch.Tensor, float, int], 237 | proj_embedding: torch.FloatTensor, 238 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 239 | attention_mask: Optional[torch.BoolTensor] = None, 240 | return_dict: bool = True, 241 | ): 242 | """ 243 | The [`PriorTransformer`] forward method. 244 | 245 | Args: 246 | hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): 247 | The currently predicted image embeddings. 248 | timestep (`torch.LongTensor`): 249 | Current denoising step. 250 | proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): 251 | Projected embedding vector the denoising process is conditioned on. 252 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): 253 | Hidden states of the text embeddings the denoising process is conditioned on. 254 | attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): 255 | Text mask for the text embeddings. 256 | return_dict (`bool`, *optional*, defaults to `True`): 257 | Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain 258 | tuple. 259 | 260 | Returns: 261 | [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: 262 | If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a 263 | tuple is returned where the first element is the sample tensor. 264 | """ 265 | batch_size = hidden_states.shape[0] 266 | 267 | # timesteps = timestep 268 | # if not torch.is_tensor(timesteps): 269 | # timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) 270 | # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 271 | # timesteps = timesteps[None].to(hidden_states.device) 272 | 273 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 274 | # timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) 275 | 276 | # timesteps_projected = self.time_proj(timesteps) 277 | 278 | # timesteps does not contain any weights and will always return f32 tensors 279 | # but time_embedding might be fp16, so we need to cast here. 
280 |         # timesteps_projected = timesteps_projected.to(dtype=self.dtype)
281 |         # time_embeddings = self.time_embedding(timesteps_projected)
282 | 
283 |         if self.embedding_proj_norm is not None:
284 |             proj_embedding = self.embedding_proj_norm(proj_embedding)
285 | 
286 |         proj_embeddings = self.embedding_proj(proj_embedding)
287 |         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
288 |             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
289 |         elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
290 |             raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
291 | 
292 |         hidden_states = self.proj_in(hidden_states)
293 | 
294 |         positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
295 | 
296 |         additional_embeds = []
297 |         additional_embeddings_len = 0
298 | 
299 |         if encoder_hidden_states is not None:
300 |             additional_embeds.append(encoder_hidden_states)
301 |             additional_embeddings_len += encoder_hidden_states.shape[1]
302 | 
303 |         if len(proj_embeddings.shape) == 2:
304 |             proj_embeddings = proj_embeddings[:, None, :]
305 | 
306 |         if len(hidden_states.shape) == 2:
307 |             hidden_states = hidden_states[:, None, :]
308 | 
309 |         additional_embeds = additional_embeds + [
310 |             proj_embeddings,
311 |             # time_embeddings[:, None, :],
312 |             hidden_states,
313 |         ]
314 | 
315 |         if self.prd_embedding is not None:
316 |             prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
317 |             additional_embeds.append(prd_embedding)
318 | 
319 |         hidden_states = torch.cat(
320 |             additional_embeds,
321 |             dim=1,
322 |         )
323 | 
324 |         # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
325 |         additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
326 |         if positional_embeddings.shape[1] < hidden_states.shape[1]:
327 |             positional_embeddings = F.pad(
328 |                 positional_embeddings,
329 |                 (
330 |                     0,
331 |                     0,
332 |                     additional_embeddings_len,
333 |                     self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
334 |                 ),
335 |                 value=0.0,
336 |             )
337 | 
338 |         hidden_states = hidden_states + positional_embeddings
339 | 
340 |         if attention_mask is not None:
341 |             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
342 |             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
343 |             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
344 |             attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
345 | 
346 |         if self.norm_in is not None:
347 |             hidden_states = self.norm_in(hidden_states)
348 | 
349 |         for block in self.transformer_blocks:
350 |             hidden_states = block(hidden_states, attention_mask=attention_mask)
351 | 
352 |         hidden_states = self.norm_out(hidden_states)
353 | 
354 |         if self.prd_embedding is not None:
355 |             hidden_states = hidden_states[:, -1]
356 |         else:
357 |             hidden_states = hidden_states[:, additional_embeddings_len:]
358 | 
359 |         predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
360 | 
361 |         if not return_dict:
362 |             return (predicted_image_embedding,)
363 | 
364 |         return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
365 | 
366 |     def post_process_latents(self, prior_latents):
367 |         prior_latents = (prior_latents * self.clip_std) + self.clip_mean
368 |         return prior_latents
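# A minimal shape-check sketch for the PriorTransformer above, using a deliberately small,
# illustrative config (the trained prior uses the defaults: 32 heads x 64 dims, 20 layers).
# Because the timestep conditioning is commented out, the prior maps a pooled CLIP text
# embedding (plus per-token text states) to an image embedding in a single forward pass.
import torch
from src.priors.prior_transformer import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=4,   # small values for a quick CPU check, not the trained config
    attention_head_dim=32,
    num_layers=2,
    embedding_dim=768,       # CLIP embedding width
    num_embeddings=77,       # CLIP text sequence length
)

batch = 2
hidden_states = torch.randn(batch, 768)               # current image-embedding estimate / noise
proj_embedding = torch.randn(batch, 768)              # pooled CLIP text embedding
encoder_hidden_states = torch.randn(batch, 77, 768)   # per-token CLIP text hidden states
attention_mask = torch.ones(batch, 77)                # 1 = keep token

out = prior(
    hidden_states,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 768])

# With trained weights, post_process_latents de-normalizes the prediction using the
# learned clip_mean / clip_std parameters (zeros at random initialization).
image_embed = prior.post_process_latents(out.predicted_image_embedding)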
--------------------------------------------------------------------------------