├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── assets
│   └── rgbx24_teaser.png
├── download_model.bat
├── pyproject.toml
├── requirements.txt
├── rgb2x
│   ├── example
│   │   └── Castlereagh_corridor_photo.png
│   ├── gradio_demo_rgb2x.py
│   ├── load_image.py
│   └── pipeline_rgb2x.py
├── rgbx.py
├── x2rgb
│   ├── example
│   │   ├── kitchen-albedo.png
│   │   ├── kitchen-irradiance.png
│   │   ├── kitchen-metallic.png
│   │   ├── kitchen-normal.png
│   │   ├── kitchen-ref.png
│   │   └── kitchen-roughness.png
│   ├── gradio_demo_x2rgb.py
│   ├── load_image.py
│   └── pipeline_x2rgb.py
└── x2rgb_inpainting
    ├── gradio_demo_x2rgb_inpainting.py
    ├── load_image.py
    └── pipeline_x2rgb_inpainting.py

/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |       - master
8 |     paths:
9 |       - "pyproject.toml"
10 | 
11 | permissions:
12 |   issues: write
13 | 
14 | jobs:
15 |   publish-node:
16 |     name: Publish Custom Node to registry
17 |     runs-on: ubuntu-latest
18 |     if: ${{ github.repository_owner == 'toyxyz' }}
19 |     steps:
20 |       - name: Check out code
21 |         uses: actions/checkout@v4
22 |       - name: Publish Custom Node
23 |         uses: Comfy-Org/publish-node-action@v1
24 |         with:
25 |           ## Add your own personal access token to your GitHub Repository secrets and reference it here.
26 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
27 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .env
3 | .dev
4 | __pycache__
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI_rgbx_Wrapper
2 | 
3 | 
4 | Original project: https://github.com/zheng95z/rgbx
5 | 
6 | This is the rgb2x wrapper node for ComfyUI.
7 | 
8 | The required models are automatically downloaded on the first run.
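
For reference, the first-run download is just Hugging Face `diffusers` populating a local cache when the node builds its pipeline. The sketch below mirrors what `rgbx.py` in this repository does (model id `zheng95z/rgb-to-x`, weights cached under `model_cache` next to the node); it assumes it is run from the repository root on a CUDA-capable machine and is meant as an illustration, not a separate API.

```python
# Minimal sketch of what the rgb2x node does on first use (mirrors rgbx.py):
# diffusers downloads "zheng95z/rgb-to-x" into ./model_cache beside the node files.
import os

import torch
from diffusers import DDIMScheduler

from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline  # run from the repo root

node_dir = os.path.dirname(os.path.abspath(__file__))
pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(node_dir, "model_cache"),  # weights land here on first run
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
```

The bundled `download_model.bat` is an optional alternative that clones the same Hugging Face model repositories ahead of time with git-lfs.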
9 | 10 | ![image](https://github.com/user-attachments/assets/8dc4f187-e02e-499e-a6b2-2bee4bdca569) 11 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | node_list = [ #Add list of .py files containing nodes here 5 | "rgbx" 6 | ] 7 | 8 | NODE_CLASS_MAPPINGS = {} 9 | NODE_DISPLAY_NAME_MAPPINGS = {} 10 | 11 | for module_name in node_list: 12 | imported_module = importlib.import_module(".{}".format(module_name), __name__) 13 | 14 | NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS} 15 | NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS} 16 | 17 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 18 | -------------------------------------------------------------------------------- /assets/rgbx24_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/assets/rgbx24_teaser.png -------------------------------------------------------------------------------- /download_model.bat: -------------------------------------------------------------------------------- 1 | git-lfs install 2 | git clone https://huggingface.co/zheng95z/x-to-rgb 3 | git clone https://huggingface.co/zheng95z/rgb-to-x 4 | 5 | pause -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_rgbx_wrapper" 3 | description = "This is the rgb2x wrapper node for ComfyUI. 
The required models are automatically downloaded on the first run.\noriginal project : [a/https://github.com/zheng95z/rgbx](original project : https://github.com/zheng95z/rgbx)" 4 | version = "1.0.1" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "diffusers", "imageio", "numpy", "opencv-python", "transformers", "huggingface-hub"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/toyxyz/ComfyUI_rgbx_Wrapper" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "toyxyz" 14 | DisplayName = "ComfyUI_rgbx_Wrapper" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | diffusers 3 | imageio 4 | numpy 5 | opencv-python 6 | transformers 7 | huggingface-hub -------------------------------------------------------------------------------- /rgb2x/example/Castlereagh_corridor_photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/rgb2x/example/Castlereagh_corridor_photo.png -------------------------------------------------------------------------------- /rgb2x/gradio_demo_rgb2x.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 4 | 5 | import gradio as gr 6 | import torch 7 | import torchvision 8 | from diffusers import DDIMScheduler 9 | from load_image import load_exr_image, load_ldr_image 10 | from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline 11 | 12 | current_directory = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | 15 | def get_rgb2x_demo(): 16 | # Load pipeline 17 | pipe = StableDiffusionAOVMatEstPipeline.from_pretrained( 18 | "zheng95z/rgb-to-x", 19 | torch_dtype=torch.float16, 20 | cache_dir=os.path.join(current_directory, "model_cache"), 21 | ).to("cuda") 22 | pipe.scheduler = DDIMScheduler.from_config( 23 | pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing" 24 | ) 25 | pipe.set_progress_bar_config(disable=True) 26 | pipe.to("cuda") 27 | 28 | # Augmentation 29 | def callback( 30 | photo, 31 | seed, 32 | inference_step, 33 | num_samples, 34 | ): 35 | generator = torch.Generator(device="cuda").manual_seed(seed) 36 | 37 | if photo.name.endswith(".exr"): 38 | photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda") 39 | elif ( 40 | photo.name.endswith(".png") 41 | or photo.name.endswith(".jpg") 42 | or photo.name.endswith(".jpeg") 43 | ): 44 | photo = load_ldr_image(photo.name, from_srgb=True).to("cuda") 45 | 46 | # Check if the width and height are multiples of 8. 
If not, crop it using torchvision.transforms.CenterCrop 47 | old_height = photo.shape[1] 48 | old_width = photo.shape[2] 49 | new_height = old_height 50 | new_width = old_width 51 | radio = old_height / old_width 52 | max_side = 1000 53 | if old_height > old_width: 54 | new_height = max_side 55 | new_width = int(new_height / radio) 56 | else: 57 | new_width = max_side 58 | new_height = int(new_width * radio) 59 | 60 | if new_width % 8 != 0 or new_height % 8 != 0: 61 | new_width = new_width // 8 * 8 62 | new_height = new_height // 8 * 8 63 | 64 | photo = torchvision.transforms.Resize((new_height, new_width))(photo) 65 | 66 | required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"] 67 | prompts = { 68 | "albedo": "Albedo (diffuse basecolor)", 69 | "normal": "Camera-space Normal", 70 | "roughness": "Roughness", 71 | "metallic": "Metallicness", 72 | "irradiance": "Irradiance (diffuse lighting)", 73 | } 74 | 75 | return_list = [] 76 | for i in range(num_samples): 77 | for aov_name in required_aovs: 78 | prompt = prompts[aov_name] 79 | generated_image = pipe( 80 | prompt=prompt, 81 | photo=photo, 82 | num_inference_steps=inference_step, 83 | height=new_height, 84 | width=new_width, 85 | generator=generator, 86 | required_aovs=[aov_name], 87 | ).images[0][0] 88 | 89 | generated_image = torchvision.transforms.Resize( 90 | (old_height, old_width) 91 | )(generated_image) 92 | 93 | generated_image = (generated_image, f"Generated {aov_name} {i}") 94 | return_list.append(generated_image) 95 | 96 | return return_list 97 | 98 | block = gr.Blocks() 99 | with block: 100 | with gr.Row(): 101 | gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)") 102 | with gr.Row(): 103 | # Input side 104 | with gr.Column(): 105 | gr.Markdown("### Given Image") 106 | photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"]) 107 | 108 | gr.Markdown("### Parameters") 109 | run_button = gr.Button(label="Run") 110 | with gr.Accordion("Advanced options", open=False): 111 | seed = gr.Slider( 112 | label="Seed", 113 | minimum=-1, 114 | maximum=2147483647, 115 | step=1, 116 | randomize=True, 117 | ) 118 | inference_step = gr.Slider( 119 | label="Inference Step", 120 | minimum=1, 121 | maximum=100, 122 | step=1, 123 | value=50, 124 | ) 125 | num_samples = gr.Slider( 126 | label="Samples", 127 | minimum=1, 128 | maximum=100, 129 | step=1, 130 | value=1, 131 | ) 132 | 133 | # Output side 134 | with gr.Column(): 135 | gr.Markdown("### Output Gallery") 136 | result_gallery = gr.Gallery( 137 | label="Output", 138 | show_label=False, 139 | elem_id="gallery", 140 | columns=2, 141 | ) 142 | 143 | inputs = [ 144 | photo, 145 | seed, 146 | inference_step, 147 | num_samples, 148 | ] 149 | run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True) 150 | 151 | return block 152 | 153 | 154 | if __name__ == "__main__": 155 | demo = get_rgb2x_demo() 156 | demo.queue(max_size=1) 157 | demo.launch() 158 | -------------------------------------------------------------------------------- /rgb2x/load_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | 6 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 7 | import numpy as np 8 | 9 | 10 | def convert_rgb_2_XYZ(rgb): 11 | # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html 12 | # rgb: (h, w, 3) 13 | # XYZ: (h, w, 3) 14 | XYZ = torch.ones_like(rgb) 15 | XYZ[:, :, 0] = ( 16 | 
0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2] 17 | ) 18 | XYZ[:, :, 1] = ( 19 | 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2] 20 | ) 21 | XYZ[:, :, 2] = ( 22 | 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2] 23 | ) 24 | return XYZ 25 | 26 | 27 | def convert_XYZ_2_Yxy(XYZ): 28 | # XYZ: (h, w, 3) 29 | # Yxy: (h, w, 3) 30 | Yxy = torch.ones_like(XYZ) 31 | Yxy[:, :, 0] = XYZ[:, :, 1] 32 | sum = torch.sum(XYZ, dim=2) 33 | inv_sum = 1.0 / torch.clamp(sum, min=1e-4) 34 | Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum 35 | Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum 36 | return Yxy 37 | 38 | 39 | def convert_rgb_2_Yxy(rgb): 40 | # rgb: (h, w, 3) 41 | # Yxy: (h, w, 3) 42 | return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb)) 43 | 44 | 45 | def convert_XYZ_2_rgb(XYZ): 46 | # XYZ: (h, w, 3) 47 | # rgb: (h, w, 3) 48 | rgb = torch.ones_like(XYZ) 49 | rgb[:, :, 0] = ( 50 | 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2] 51 | ) 52 | rgb[:, :, 1] = ( 53 | -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2] 54 | ) 55 | rgb[:, :, 2] = ( 56 | 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2] 57 | ) 58 | return rgb 59 | 60 | 61 | def convert_Yxy_2_XYZ(Yxy): 62 | # Yxy: (h, w, 3) 63 | # XYZ: (h, w, 3) 64 | XYZ = torch.ones_like(Yxy) 65 | XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0] 66 | XYZ[:, :, 1] = Yxy[:, :, 0] 67 | XYZ[:, :, 2] = ( 68 | (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2]) 69 | / torch.clamp(Yxy[:, :, 2], min=1e-4) 70 | * Yxy[:, :, 0] 71 | ) 72 | return XYZ 73 | 74 | 75 | def convert_Yxy_2_rgb(Yxy): 76 | # Yxy: (h, w, 3) 77 | # rgb: (h, w, 3) 78 | return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy)) 79 | 80 | 81 | def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False): 82 | # Load png or jpg image 83 | image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 84 | image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c) 85 | image[~torch.isfinite(image)] = 0 86 | if from_srgb: 87 | # Convert from sRGB to linear RGB 88 | image = image**2.2 89 | if clamp: 90 | image = torch.clamp(image, min=0.0, max=1.0) 91 | if normalize: 92 | # Normalize to [-1, 1] 93 | image = image * 2.0 - 1.0 94 | image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6) 95 | return image.permute(2, 0, 1) # returns (c, h, w) 96 | 97 | 98 | def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False): 99 | image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB) 100 | image = torch.from_numpy(image.astype("float32")) # (h, w, c) 101 | image[~torch.isfinite(image)] = 0 102 | if tonemaping: 103 | # Exposure adjuestment 104 | image_Yxy = convert_rgb_2_Yxy(image) 105 | lum = ( 106 | image[:, :, 0:1] * 0.2125 107 | + image[:, :, 1:2] * 0.7154 108 | + image[:, :, 2:3] * 0.0721 109 | ) 110 | lum = torch.log(torch.clamp(lum, min=1e-6)) 111 | lum_mean = torch.exp(torch.mean(lum)) 112 | lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6) 113 | image_Yxy[:, :, 0:1] = lp 114 | image = convert_Yxy_2_rgb(image_Yxy) 115 | if clamp: 116 | image = torch.clamp(image, min=0.0, max=1.0) 117 | if normalize: 118 | image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6) 119 | return image.permute(2, 0, 1) # returns (c, h, w) 120 | -------------------------------------------------------------------------------- /rgb2x/pipeline_rgb2x.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import PIL 7 | import torch 8 | from diffusers.configuration_utils import register_to_config 9 | from diffusers.image_processor import VaeImageProcessor 10 | from diffusers.loaders import ( 11 | LoraLoaderMixin, 12 | TextualInversionLoaderMixin, 13 | ) 14 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 15 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 16 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 17 | rescale_noise_cfg, 18 | ) 19 | from diffusers.schedulers import KarrasDiffusionSchedulers 20 | from diffusers.utils import ( 21 | CONFIG_NAME, 22 | BaseOutput, 23 | deprecate, 24 | logging, 25 | ) 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from transformers import CLIPTextModel, CLIPTokenizer 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class VaeImageProcrssorAOV(VaeImageProcessor): 33 | """ 34 | Image processor for VAE AOV. 35 | 36 | Args: 37 | do_resize (`bool`, *optional*, defaults to `True`): 38 | Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. 39 | vae_scale_factor (`int`, *optional*, defaults to `8`): 40 | VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. 41 | resample (`str`, *optional*, defaults to `lanczos`): 42 | Resampling filter to use when resizing the image. 43 | do_normalize (`bool`, *optional*, defaults to `True`): 44 | Whether to normalize the image to [-1,1]. 45 | """ 46 | 47 | config_name = CONFIG_NAME 48 | 49 | @register_to_config 50 | def __init__( 51 | self, 52 | do_resize: bool = True, 53 | vae_scale_factor: int = 8, 54 | resample: str = "lanczos", 55 | do_normalize: bool = True, 56 | ): 57 | super().__init__() 58 | 59 | def postprocess( 60 | self, 61 | image: torch.FloatTensor, 62 | output_type: str = "pil", 63 | do_denormalize: Optional[List[bool]] = None, 64 | do_gamma_correction: bool = True, 65 | ): 66 | if not isinstance(image, torch.Tensor): 67 | raise ValueError( 68 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" 69 | ) 70 | if output_type not in ["latent", "pt", "np", "pil"]: 71 | deprecation_message = ( 72 | f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " 73 | "`pil`, `np`, `pt`, `latent`" 74 | ) 75 | deprecate( 76 | "Unsupported output_type", 77 | "1.0.0", 78 | deprecation_message, 79 | standard_warn=False, 80 | ) 81 | output_type = "np" 82 | 83 | if output_type == "latent": 84 | return image 85 | 86 | if do_denormalize is None: 87 | do_denormalize = [self.config.do_normalize] * image.shape[0] 88 | 89 | image = torch.stack( 90 | [ 91 | self.denormalize(image[i]) if do_denormalize[i] else image[i] 92 | for i in range(image.shape[0]) 93 | ] 94 | ) 95 | 96 | # Gamma correction 97 | if do_gamma_correction: 98 | image = torch.pow(image, 1.0 / 2.2) 99 | 100 | if output_type == "pt": 101 | return image 102 | 103 | image = self.pt_to_numpy(image) 104 | 105 | if output_type == "np": 106 | return image 107 | 108 | if output_type == "pil": 109 | return self.numpy_to_pil(image) 110 | 111 | def preprocess_normal( 112 | self, 113 | image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], 114 | height: Optional[int] = None, 115 | width: Optional[int] = None, 116 | ) -> torch.Tensor: 117 | image = torch.stack([image], axis=0) 118 | return image 119 | 120 | 121 | @dataclass 122 | class StableDiffusionAOVPipelineOutput(BaseOutput): 123 | """ 124 | Output class for Stable Diffusion AOV pipelines. 125 | 126 | Args: 127 | images (`List[PIL.Image.Image]` or `np.ndarray`) 128 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 129 | num_channels)`. 130 | nsfw_content_detected (`List[bool]`) 131 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or 132 | `None` if safety checking could not be performed. 133 | """ 134 | 135 | images: Union[List[PIL.Image.Image], np.ndarray] 136 | 137 | 138 | class StableDiffusionAOVMatEstPipeline( 139 | DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin 140 | ): 141 | r""" 142 | Pipeline for AOVs. 143 | 144 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 145 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 146 | 147 | The pipeline also inherits the following loading methods: 148 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 149 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 150 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 151 | 152 | Args: 153 | vae ([`AutoencoderKL`]): 154 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 155 | text_encoder ([`~transformers.CLIPTextModel`]): 156 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 157 | tokenizer ([`~transformers.CLIPTokenizer`]): 158 | A `CLIPTokenizer` to tokenize text. 159 | unet ([`UNet2DConditionModel`]): 160 | A `UNet2DConditionModel` to denoise the encoded image latents. 161 | scheduler ([`SchedulerMixin`]): 162 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 163 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
164 | """ 165 | 166 | def __init__( 167 | self, 168 | vae: AutoencoderKL, 169 | text_encoder: CLIPTextModel, 170 | tokenizer: CLIPTokenizer, 171 | unet: UNet2DConditionModel, 172 | scheduler: KarrasDiffusionSchedulers, 173 | ): 174 | super().__init__() 175 | 176 | self.register_modules( 177 | vae=vae, 178 | text_encoder=text_encoder, 179 | tokenizer=tokenizer, 180 | unet=unet, 181 | scheduler=scheduler, 182 | ) 183 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 184 | self.image_processor = VaeImageProcrssorAOV( 185 | vae_scale_factor=self.vae_scale_factor 186 | ) 187 | self.register_to_config() 188 | 189 | def _encode_prompt( 190 | self, 191 | prompt, 192 | device, 193 | num_images_per_prompt, 194 | do_classifier_free_guidance, 195 | negative_prompt=None, 196 | prompt_embeds: Optional[torch.FloatTensor] = None, 197 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 198 | ): 199 | r""" 200 | Encodes the prompt into text encoder hidden states. 201 | 202 | Args: 203 | prompt (`str` or `List[str]`, *optional*): 204 | prompt to be encoded 205 | device: (`torch.device`): 206 | torch device 207 | num_images_per_prompt (`int`): 208 | number of images that should be generated per prompt 209 | do_classifier_free_guidance (`bool`): 210 | whether to use classifier free guidance or not 211 | negative_ prompt (`str` or `List[str]`, *optional*): 212 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 213 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 214 | less than `1`). 215 | prompt_embeds (`torch.FloatTensor`, *optional*): 216 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 217 | provided, text embeddings will be generated from `prompt` input argument. 218 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 219 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 220 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 221 | argument. 
222 | """ 223 | if prompt is not None and isinstance(prompt, str): 224 | batch_size = 1 225 | elif prompt is not None and isinstance(prompt, list): 226 | batch_size = len(prompt) 227 | else: 228 | batch_size = prompt_embeds.shape[0] 229 | 230 | if prompt_embeds is None: 231 | # textual inversion: procecss multi-vector tokens if necessary 232 | if isinstance(self, TextualInversionLoaderMixin): 233 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 234 | 235 | text_inputs = self.tokenizer( 236 | prompt, 237 | padding="max_length", 238 | max_length=self.tokenizer.model_max_length, 239 | truncation=True, 240 | return_tensors="pt", 241 | ) 242 | text_input_ids = text_inputs.input_ids 243 | untruncated_ids = self.tokenizer( 244 | prompt, padding="longest", return_tensors="pt" 245 | ).input_ids 246 | 247 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 248 | -1 249 | ] and not torch.equal(text_input_ids, untruncated_ids): 250 | removed_text = self.tokenizer.batch_decode( 251 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 252 | ) 253 | logger.warning( 254 | "The following part of your input was truncated because CLIP can only handle sequences up to" 255 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 256 | ) 257 | 258 | if ( 259 | hasattr(self.text_encoder.config, "use_attention_mask") 260 | and self.text_encoder.config.use_attention_mask 261 | ): 262 | attention_mask = text_inputs.attention_mask.to(device) 263 | else: 264 | attention_mask = None 265 | 266 | prompt_embeds = self.text_encoder( 267 | text_input_ids.to(device), 268 | attention_mask=attention_mask, 269 | ) 270 | prompt_embeds = prompt_embeds[0] 271 | 272 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 273 | 274 | bs_embed, seq_len, _ = prompt_embeds.shape 275 | # duplicate text embeddings for each generation per prompt, using mps friendly method 276 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 277 | prompt_embeds = prompt_embeds.view( 278 | bs_embed * num_images_per_prompt, seq_len, -1 279 | ) 280 | 281 | # get unconditional embeddings for classifier free guidance 282 | if do_classifier_free_guidance and negative_prompt_embeds is None: 283 | uncond_tokens: List[str] 284 | if negative_prompt is None: 285 | uncond_tokens = [""] * batch_size 286 | elif type(prompt) is not type(negative_prompt): 287 | raise TypeError( 288 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 289 | f" {type(prompt)}." 290 | ) 291 | elif isinstance(negative_prompt, str): 292 | uncond_tokens = [negative_prompt] 293 | elif batch_size != len(negative_prompt): 294 | raise ValueError( 295 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 296 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 297 | " the batch size of `prompt`." 
298 | ) 299 | else: 300 | uncond_tokens = negative_prompt 301 | 302 | # textual inversion: procecss multi-vector tokens if necessary 303 | if isinstance(self, TextualInversionLoaderMixin): 304 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 305 | 306 | max_length = prompt_embeds.shape[1] 307 | uncond_input = self.tokenizer( 308 | uncond_tokens, 309 | padding="max_length", 310 | max_length=max_length, 311 | truncation=True, 312 | return_tensors="pt", 313 | ) 314 | 315 | if ( 316 | hasattr(self.text_encoder.config, "use_attention_mask") 317 | and self.text_encoder.config.use_attention_mask 318 | ): 319 | attention_mask = uncond_input.attention_mask.to(device) 320 | else: 321 | attention_mask = None 322 | 323 | negative_prompt_embeds = self.text_encoder( 324 | uncond_input.input_ids.to(device), 325 | attention_mask=attention_mask, 326 | ) 327 | negative_prompt_embeds = negative_prompt_embeds[0] 328 | 329 | if do_classifier_free_guidance: 330 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 331 | seq_len = negative_prompt_embeds.shape[1] 332 | 333 | negative_prompt_embeds = negative_prompt_embeds.to( 334 | dtype=self.text_encoder.dtype, device=device 335 | ) 336 | 337 | negative_prompt_embeds = negative_prompt_embeds.repeat( 338 | 1, num_images_per_prompt, 1 339 | ) 340 | negative_prompt_embeds = negative_prompt_embeds.view( 341 | batch_size * num_images_per_prompt, seq_len, -1 342 | ) 343 | 344 | # For classifier free guidance, we need to do two forward passes. 345 | # Here we concatenate the unconditional and text embeddings into a single batch 346 | # to avoid doing two forward passes 347 | # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 348 | prompt_embeds = torch.cat( 349 | [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 350 | ) 351 | 352 | return prompt_embeds 353 | 354 | def prepare_extra_step_kwargs(self, generator, eta): 355 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 356 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 357 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 358 | # and should be between [0, 1] 359 | 360 | accepts_eta = "eta" in set( 361 | inspect.signature(self.scheduler.step).parameters.keys() 362 | ) 363 | extra_step_kwargs = {} 364 | if accepts_eta: 365 | extra_step_kwargs["eta"] = eta 366 | 367 | # check if the scheduler accepts generator 368 | accepts_generator = "generator" in set( 369 | inspect.signature(self.scheduler.step).parameters.keys() 370 | ) 371 | if accepts_generator: 372 | extra_step_kwargs["generator"] = generator 373 | return extra_step_kwargs 374 | 375 | def check_inputs( 376 | self, 377 | prompt, 378 | callback_steps, 379 | negative_prompt=None, 380 | prompt_embeds=None, 381 | negative_prompt_embeds=None, 382 | ): 383 | if (callback_steps is None) or ( 384 | callback_steps is not None 385 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 386 | ): 387 | raise ValueError( 388 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 389 | f" {type(callback_steps)}." 390 | ) 391 | 392 | if prompt is not None and prompt_embeds is not None: 393 | raise ValueError( 394 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" 395 | " only forward one of the two." 396 | ) 397 | elif prompt is None and prompt_embeds is None: 398 | raise ValueError( 399 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 400 | ) 401 | elif prompt is not None and ( 402 | not isinstance(prompt, str) and not isinstance(prompt, list) 403 | ): 404 | raise ValueError( 405 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 406 | ) 407 | 408 | if negative_prompt is not None and negative_prompt_embeds is not None: 409 | raise ValueError( 410 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 411 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 412 | ) 413 | 414 | if prompt_embeds is not None and negative_prompt_embeds is not None: 415 | if prompt_embeds.shape != negative_prompt_embeds.shape: 416 | raise ValueError( 417 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 418 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 419 | f" {negative_prompt_embeds.shape}." 420 | ) 421 | 422 | def prepare_latents( 423 | self, 424 | batch_size, 425 | num_channels_latents, 426 | height, 427 | width, 428 | dtype, 429 | device, 430 | generator, 431 | latents=None, 432 | ): 433 | shape = ( 434 | batch_size, 435 | num_channels_latents, 436 | height // self.vae_scale_factor, 437 | width // self.vae_scale_factor, 438 | ) 439 | if isinstance(generator, list) and len(generator) != batch_size: 440 | raise ValueError( 441 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 442 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 443 | ) 444 | 445 | if latents is None: 446 | latents = randn_tensor( 447 | shape, generator=generator, device=device, dtype=dtype 448 | ) 449 | else: 450 | latents = latents.to(device) 451 | 452 | # scale the initial noise by the standard deviation required by the scheduler 453 | latents = latents * self.scheduler.init_noise_sigma 454 | return latents 455 | 456 | def prepare_image_latents( 457 | self, 458 | image, 459 | batch_size, 460 | num_images_per_prompt, 461 | dtype, 462 | device, 463 | do_classifier_free_guidance, 464 | generator=None, 465 | ): 466 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 467 | raise ValueError( 468 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 469 | ) 470 | 471 | image = image.to(device=device, dtype=dtype) 472 | 473 | batch_size = batch_size * num_images_per_prompt 474 | 475 | if image.shape[1] == 4: 476 | image_latents = image 477 | else: 478 | if isinstance(generator, list) and len(generator) != batch_size: 479 | raise ValueError( 480 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 481 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
482 | ) 483 | 484 | if isinstance(generator, list): 485 | image_latents = [ 486 | self.vae.encode(image[i : i + 1]).latent_dist.mode() 487 | for i in range(batch_size) 488 | ] 489 | image_latents = torch.cat(image_latents, dim=0) 490 | else: 491 | image_latents = self.vae.encode(image).latent_dist.mode() 492 | 493 | if ( 494 | batch_size > image_latents.shape[0] 495 | and batch_size % image_latents.shape[0] == 0 496 | ): 497 | # expand image_latents for batch_size 498 | deprecation_message = ( 499 | f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" 500 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 501 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 502 | " your script to pass as many initial images as text prompts to suppress this warning." 503 | ) 504 | deprecate( 505 | "len(prompt) != len(image)", 506 | "1.0.0", 507 | deprecation_message, 508 | standard_warn=False, 509 | ) 510 | additional_image_per_prompt = batch_size // image_latents.shape[0] 511 | image_latents = torch.cat( 512 | [image_latents] * additional_image_per_prompt, dim=0 513 | ) 514 | elif ( 515 | batch_size > image_latents.shape[0] 516 | and batch_size % image_latents.shape[0] != 0 517 | ): 518 | raise ValueError( 519 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 520 | ) 521 | else: 522 | image_latents = torch.cat([image_latents], dim=0) 523 | 524 | if do_classifier_free_guidance: 525 | uncond_image_latents = torch.zeros_like(image_latents) 526 | image_latents = torch.cat( 527 | [image_latents, image_latents, uncond_image_latents], dim=0 528 | ) 529 | 530 | return image_latents 531 | 532 | @torch.no_grad() 533 | def __call__( 534 | self, 535 | prompt: Union[str, List[str]] = None, 536 | photo: Union[ 537 | torch.FloatTensor, 538 | PIL.Image.Image, 539 | np.ndarray, 540 | List[torch.FloatTensor], 541 | List[PIL.Image.Image], 542 | List[np.ndarray], 543 | ] = None, 544 | height: Optional[int] = None, 545 | width: Optional[int] = None, 546 | num_inference_steps: int = 100, 547 | required_aovs: List[str] = ["albedo"], 548 | negative_prompt: Optional[Union[str, List[str]]] = None, 549 | num_images_per_prompt: Optional[int] = 1, 550 | use_default_scaling_factor: Optional[bool] = False, 551 | guidance_scale: float = 0.0, 552 | image_guidance_scale: float = 0.0, 553 | guidance_rescale: float = 0.0, 554 | eta: float = 0.0, 555 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 556 | latents: Optional[torch.FloatTensor] = None, 557 | prompt_embeds: Optional[torch.FloatTensor] = None, 558 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 559 | output_type: Optional[str] = "pil", 560 | return_dict: bool = True, 561 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 562 | callback_steps: int = 1, 563 | ): 564 | r""" 565 | The call function to the pipeline for generation. 566 | 567 | Args: 568 | prompt (`str` or `List[str]`, *optional*): 569 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 570 | image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 571 | `Image` or tensor representing an image batch to be repainted according to `prompt`. 
Can also accept 572 | image latents as `image`, but if passing latents directly it is not encoded again. 573 | num_inference_steps (`int`, *optional*, defaults to 100): 574 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 575 | expense of slower inference. 576 | guidance_scale (`float`, *optional*, defaults to 7.5): 577 | A higher guidance scale value encourages the model to generate images closely linked to the text 578 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 579 | image_guidance_scale (`float`, *optional*, defaults to 1.5): 580 | Push the generated image towards the inital `image`. Image guidance scale is enabled by setting 581 | `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely 582 | linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a 583 | value of at least `1`. 584 | negative_prompt (`str` or `List[str]`, *optional*): 585 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 586 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 587 | num_images_per_prompt (`int`, *optional*, defaults to 1): 588 | The number of images to generate per prompt. 589 | eta (`float`, *optional*, defaults to 0.0): 590 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 591 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 592 | generator (`torch.Generator`, *optional*): 593 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 594 | generation deterministic. 595 | latents (`torch.FloatTensor`, *optional*): 596 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 597 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 598 | tensor is generated by sampling using the supplied random `generator`. 599 | prompt_embeds (`torch.FloatTensor`, *optional*): 600 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 601 | provided, text embeddings are generated from the `prompt` input argument. 602 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 603 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 604 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 605 | output_type (`str`, *optional*, defaults to `"pil"`): 606 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 607 | return_dict (`bool`, *optional*, defaults to `True`): 608 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 609 | plain tuple. 610 | callback (`Callable`, *optional*): 611 | A function that calls every `callback_steps` steps during inference. The function is called with the 612 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 613 | callback_steps (`int`, *optional*, defaults to 1): 614 | The frequency at which the `callback` function is called. If not specified, the callback is called at 615 | every step. 
616 | 617 | Examples: 618 | 619 | ```py 620 | >>> import PIL 621 | >>> import requests 622 | >>> import torch 623 | >>> from io import BytesIO 624 | 625 | >>> from diffusers import StableDiffusionInstructPix2PixPipeline 626 | 627 | 628 | >>> def download_image(url): 629 | ... response = requests.get(url) 630 | ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") 631 | 632 | 633 | >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" 634 | 635 | >>> image = download_image(img_url).resize((512, 512)) 636 | 637 | >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( 638 | ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 639 | ... ) 640 | >>> pipe = pipe.to("cuda") 641 | 642 | >>> prompt = "make the mountains snowy" 643 | >>> image = pipe(prompt=prompt, image=image).images[0] 644 | ``` 645 | 646 | Returns: 647 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 648 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 649 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 650 | second element is a list of `bool`s indicating whether the corresponding generated image contains 651 | "not-safe-for-work" (nsfw) content. 652 | """ 653 | # 0. Check inputs 654 | self.check_inputs( 655 | prompt, 656 | callback_steps, 657 | negative_prompt, 658 | prompt_embeds, 659 | negative_prompt_embeds, 660 | ) 661 | 662 | # 1. Define call parameters 663 | if prompt is not None and isinstance(prompt, str): 664 | batch_size = 1 665 | elif prompt is not None and isinstance(prompt, list): 666 | batch_size = len(prompt) 667 | else: 668 | batch_size = prompt_embeds.shape[0] 669 | 670 | device = self._execution_device 671 | do_classifier_free_guidance = ( 672 | guidance_scale > 1.0 and image_guidance_scale >= 1.0 673 | ) 674 | # check if scheduler is in sigmas space 675 | scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") 676 | 677 | # 2. Encode input prompt 678 | prompt_embeds = self._encode_prompt( 679 | prompt, 680 | device, 681 | num_images_per_prompt, 682 | do_classifier_free_guidance, 683 | negative_prompt, 684 | prompt_embeds=prompt_embeds, 685 | negative_prompt_embeds=negative_prompt_embeds, 686 | ) 687 | 688 | # 3. Preprocess image 689 | # Normalize image to [-1,1] 690 | preprocessed_photo = self.image_processor.preprocess(photo) 691 | 692 | # 4. set timesteps 693 | self.scheduler.set_timesteps(num_inference_steps, device=device) 694 | timesteps = self.scheduler.timesteps 695 | 696 | # 5. Prepare Image latents 697 | image_latents = self.prepare_image_latents( 698 | preprocessed_photo, 699 | batch_size, 700 | num_images_per_prompt, 701 | prompt_embeds.dtype, 702 | device, 703 | do_classifier_free_guidance, 704 | generator, 705 | ) 706 | image_latents = image_latents * self.vae.config.scaling_factor 707 | 708 | height, width = image_latents.shape[-2:] 709 | height = height * self.vae_scale_factor 710 | width = width * self.vae_scale_factor 711 | 712 | # 6. Prepare latent variables 713 | num_channels_latents = self.unet.config.out_channels 714 | latents = self.prepare_latents( 715 | batch_size * num_images_per_prompt, 716 | num_channels_latents, 717 | height, 718 | width, 719 | prompt_embeds.dtype, 720 | device, 721 | generator, 722 | latents, 723 | ) 724 | 725 | # 7. 
Check that shapes of latents and image match the UNet channels 726 | num_channels_image = image_latents.shape[1] 727 | if num_channels_latents + num_channels_image != self.unet.config.in_channels: 728 | raise ValueError( 729 | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" 730 | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" 731 | f" `num_channels_image`: {num_channels_image} " 732 | f" = {num_channels_latents+num_channels_image}. Please verify the config of" 733 | " `pipeline.unet` or your `image` input." 734 | ) 735 | 736 | # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 737 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 738 | 739 | # 9. Denoising loop 740 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 741 | with self.progress_bar(total=num_inference_steps) as progress_bar: 742 | for i, t in enumerate(timesteps): 743 | # Expand the latents if we are doing classifier free guidance. 744 | # The latents are expanded 3 times because for pix2pix the guidance\ 745 | # is applied for both the text and the input image. 746 | latent_model_input = ( 747 | torch.cat([latents] * 3) if do_classifier_free_guidance else latents 748 | ) 749 | 750 | # concat latents, image_latents in the channel dimension 751 | scaled_latent_model_input = self.scheduler.scale_model_input( 752 | latent_model_input, t 753 | ) 754 | scaled_latent_model_input = torch.cat( 755 | [scaled_latent_model_input, image_latents], dim=1 756 | ) 757 | 758 | # predict the noise residual 759 | noise_pred = self.unet( 760 | scaled_latent_model_input, 761 | t, 762 | encoder_hidden_states=prompt_embeds, 763 | return_dict=False, 764 | )[0] 765 | 766 | # perform guidance 767 | if do_classifier_free_guidance: 768 | ( 769 | noise_pred_text, 770 | noise_pred_image, 771 | noise_pred_uncond, 772 | ) = noise_pred.chunk(3) 773 | noise_pred = ( 774 | noise_pred_uncond 775 | + guidance_scale * (noise_pred_text - noise_pred_image) 776 | + image_guidance_scale * (noise_pred_image - noise_pred_uncond) 777 | ) 778 | 779 | if do_classifier_free_guidance and guidance_rescale > 0.0: 780 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf
781 |                     noise_pred = rescale_noise_cfg(
782 |                         noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
783 |                     )
784 | 
785 |                 # compute the previous noisy sample x_t -> x_t-1
786 |                 latents = self.scheduler.step(
787 |                     noise_pred, t, latents, **extra_step_kwargs, return_dict=False
788 |                 )[0]
789 | 
790 |                 # call the callback, if provided
791 |                 if i == len(timesteps) - 1 or (
792 |                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
793 |                 ):
794 |                     progress_bar.update()
795 |                     if callback is not None and i % callback_steps == 0:
796 |                         callback(i, t, latents)
797 | 
798 |         aov_latents = latents / self.vae.config.scaling_factor
799 |         aov = self.vae.decode(aov_latents, return_dict=False)[0]
800 |         do_denormalize = [True] * aov.shape[0]
801 |         aov_name = required_aovs[0]
802 |         if aov_name == "albedo" or aov_name == "irradiance":
803 |             do_gamma_correction = True
804 |         else:
805 |             do_gamma_correction = False
806 | 
807 |         if aov_name == "roughness" or aov_name == "metallic":
808 |             aov = aov[:, 0:1].repeat(1, 3, 1, 1)
809 | 
810 |         aov = self.image_processor.postprocess(
811 |             aov,
812 |             output_type=output_type,
813 |             do_denormalize=do_denormalize,
814 |             do_gamma_correction=do_gamma_correction,
815 |         )
816 |         aovs = [aov]
817 | 
818 |         # Offload last model to CPU
819 |         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
820 |             self.final_offload_hook.offload()
821 |         return StableDiffusionAOVPipelineOutput(images=aovs)
822 | 
--------------------------------------------------------------------------------
/rgbx.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFile
2 | import comfy.utils
3 | import numpy as np
4 | import cv2
5 | import sys
6 | from pathlib import Path
7 | from nodes import MAX_RESOLUTION, SaveImage, common_ksampler
8 | 
9 | import os
10 | import torch
11 | import torchvision
12 | from torchvision.transforms import ToTensor
13 | from diffusers import DDIMScheduler
14 | from .rgb2x.load_image import load_exr_image, load_ldr_image
15 | from .rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
16 | 
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 | 
19 | 
20 | def process_single_aov(torch_image, aov_name='albedo', seed=42, inference_step=50):
21 |     """
22 |     Generates the requested AOV map for a single torch tensor image and returns it as a torch tensor.
23 |     If the input tensor is in (B, H, W, C) format, the result is also returned in (B, H, W, C) format after the RGB check and conversion.
24 | 
25 |     Args:
26 |         torch_image (torch.Tensor): Input image to process, in (B, H, W, C) format.
27 |         aov_name (str): Name of the AOV map to generate (default: 'albedo').
28 |         seed (int): Random seed value.
29 |         inference_step (int): Number of model inference steps.
30 | 
31 |     Returns:
32 |         torch.Tensor: Generated AOV map tensor in (B, H, W, C) format.
33 |     """
34 |     # List of supported AOVs
35 |     supported_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
36 | 
37 |     # Validate the requested AOV
38 |     if aov_name.lower() not in supported_aovs:
39 |         raise ValueError(f"Unsupported AOV. Choose one of the following: {', '.join(supported_aovs)}")
40 | 
41 |     # Prompt definitions
42 |     prompts = {
43 |         "albedo": "Albedo (diffuse basecolor)",
44 |         "normal": "Camera-space Normal",
45 |         "roughness": "Roughness",
46 |         "metallic": "Metallicness",
47 |         "irradiance": "Irradiance (diffuse lighting)",
48 |     }
49 | 
50 |     # Check the input tensor
51 |     if len(torch_image.shape) != 4:
52 |         raise ValueError("input tensor must be (B, H, W, C)")
53 | 
54 |     # Convert (B, H, W, C) -> (B, C, H, W)
55 |     torch_image = torch_image.permute(0, 3, 1, 2) # (B, C, H, W)
56 | 
57 |     # Use only the first image in the batch
58 |     photo = torch_image[0] # select the first batch entry (C, H, W)
59 | 
60 |     photo = photo**2.2
61 | 
62 |     # Resize the image (so the dimensions are divisible by 8)
63 |     old_height, old_width = photo.shape[1], photo.shape[2]
64 |     old_aspect_ratio = old_height / old_width
65 |     max_side = 1000
66 | 
67 |     if old_height > old_width:
68 |         new_height = max_side
69 |         new_width = int(new_height / old_aspect_ratio)
70 |     else:
71 |         new_width = max_side
72 |         new_height = int(new_width * old_aspect_ratio)
73 | 
74 |     # Round the size down to a multiple of 8
75 |     new_width = new_width // 8 * 8
76 |     new_height = new_height // 8 * 8
77 | 
78 |     resize_transform = torchvision.transforms.Resize((new_height, new_width))
79 |     photo = resize_transform(photo.unsqueeze(0)).squeeze(0) # resize
80 | 
81 |     # Set the random seed
82 |     generator = torch.Generator(device="cuda").manual_seed(seed)
83 | 
84 |     # Generate the selected AOV image
85 |     prompt = prompts[aov_name.lower()]
86 |     pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
87 |         "zheng95z/rgb-to-x",
88 |         torch_dtype=torch.float16,
89 |         cache_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_cache"),
90 |     ).to("cuda")
91 |     pipe.scheduler = DDIMScheduler.from_config(
92 |         pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
93 |     )
94 |     pipe.set_progress_bar_config(disable=True)
95 |     pipe.to("cuda")
96 | 
97 |     generated_image = pipe(
98 |         prompt=prompt,
99 |         photo=photo.unsqueeze(0).to("cuda"), # (B=1, C, H, W)
100 |         num_inference_steps=inference_step,
101 |         height=new_height,
102 |         width=new_width,
103 |         generator=generator,
104 |         required_aovs=[aov_name.lower()],
105 |     ).images[0][0]
106 | 
107 |     # Convert the PIL image to a torch tensor
108 |     generated_image_tensor = ToTensor()(generated_image) # (C, H, W)
109 | 
110 | 
111 |     # Convert (C, H, W) -> (B, H, W, C) and return
112 |     generated_image_tensor = generated_image_tensor.permute(1, 2, 0).unsqueeze(0) # (B=1, H, W, C)
113 | 
114 |     return generated_image_tensor
115 | 
116 | 
117 | 
118 | 
119 | class rgb2x:
120 |     @classmethod
121 |     def INPUT_TYPES(s):
122 |         return {
123 |             "required": {
124 |                 "image": ("IMAGE",),
125 |                 "aov": (("albedo", "normal", "roughness", "metallic", "irradiance"), {"default": "albedo"}),
126 |                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
127 |                 "steps": ("INT", { "default": 50, "min": 1, "max": 0xffffffffffffffff, "step": 1, }),
128 |             }
129 |         }
130 | 
131 |     RETURN_TYPES = ("IMAGE",)
132 |     RETURN_NAMES = ("IMAGE",)
133 |     FUNCTION = "execute"
134 |     CATEGORY = "ToyxyzTestNodes"
135 | 
136 |     def execute(self, image: torch.Tensor, aov, seed, steps):
137 | 
138 |         output = process_single_aov(image, aov, seed, steps)
139 | 
140 |         return(output, )
141 | 
142 | 
143 | NODE_CLASS_MAPPINGS = {
144 |     "rgb2x": rgb2x,
145 | }
146 | 
147 | NODE_DISPLAY_NAME_MAPPINGS = {
148 |     "rgb2x": "rgb2x"
149 | }
150 | 
151 | 
--------------------------------------------------------------------------------
/x2rgb/example/kitchen-albedo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-albedo.png -------------------------------------------------------------------------------- /x2rgb/example/kitchen-irradiance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-irradiance.png -------------------------------------------------------------------------------- /x2rgb/example/kitchen-metallic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-metallic.png -------------------------------------------------------------------------------- /x2rgb/example/kitchen-normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-normal.png -------------------------------------------------------------------------------- /x2rgb/example/kitchen-ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-ref.png -------------------------------------------------------------------------------- /x2rgb/example/kitchen-roughness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/toyxyz/ComfyUI_rgbx_Wrapper/06fdd728252588839f2bfad5d9a3567a2ba7c2e7/x2rgb/example/kitchen-roughness.png -------------------------------------------------------------------------------- /x2rgb/gradio_demo_x2rgb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 4 | 5 | import gradio as gr 6 | import numpy as np 7 | import torch 8 | from diffusers import DDIMScheduler 9 | from load_image import load_exr_image, load_ldr_image 10 | from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline 11 | 12 | current_directory = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | 15 | def get_x2rgb_demo(): 16 | # Load pipeline 17 | pipe = StableDiffusionAOVDropoutPipeline.from_pretrained( 18 | "zheng95z/x-to-rgb", 19 | torch_dtype=torch.float16, 20 | cache_dir=os.path.join(current_directory, "model_cache"), 21 | ).to("cuda") 22 | pipe.scheduler = DDIMScheduler.from_config( 23 | pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing" 24 | ) 25 | pipe.set_progress_bar_config(disable=True) 26 | pipe.to("cuda") 27 | 28 | # Augmentation 29 | def callback( 30 | albedo, 31 | normal, 32 | roughness, 33 | metallic, 34 | irradiance, 35 | prompt, 36 | seed, 37 | inference_step, 38 | num_samples, 39 | guidance_scale, 40 | image_guidance_scale, 41 | ): 42 | if albedo is None: 43 | albedo_image = None 44 | elif albedo.name.endswith(".exr"): 45 | albedo_image = load_exr_image(albedo.name, clamp=True).to("cuda") 46 | elif ( 47 | albedo.name.endswith(".png") 48 | or albedo.name.endswith(".jpg") 49 | or albedo.name.endswith(".jpeg") 50 | ): 51 | albedo_image = load_ldr_image(albedo.name, from_srgb=True).to("cuda") 52 | 53 | if normal is None: 54 | normal_image = None 55 | elif normal.name.endswith(".exr"): 56 | 
normal_image = load_exr_image(normal.name, normalize=True).to("cuda") 57 | elif ( 58 | normal.name.endswith(".png") 59 | or normal.name.endswith(".jpg") 60 | or normal.name.endswith(".jpeg") 61 | ): 62 | normal_image = load_ldr_image(normal.name, normalize=True).to("cuda") 63 | 64 | if roughness is None: 65 | roughness_image = None 66 | elif roughness.name.endswith(".exr"): 67 | roughness_image = load_exr_image(roughness.name, clamp=True).to("cuda") 68 | elif ( 69 | roughness.name.endswith(".png") 70 | or roughness.name.endswith(".jpg") 71 | or roughness.name.endswith(".jpeg") 72 | ): 73 | roughness_image = load_ldr_image(roughness.name, clamp=True).to("cuda") 74 | 75 | if metallic is None: 76 | metallic_image = None 77 | elif metallic.name.endswith(".exr"): 78 | metallic_image = load_exr_image(metallic.name, clamp=True).to("cuda") 79 | elif ( 80 | metallic.name.endswith(".png") 81 | or metallic.name.endswith(".jpg") 82 | or metallic.name.endswith(".jpeg") 83 | ): 84 | metallic_image = load_ldr_image(metallic.name, clamp=True).to("cuda") 85 | 86 | if irradiance is None: 87 | irradiance_image = None 88 | elif irradiance.name.endswith(".exr"): 89 | irradiance_image = load_exr_image( 90 | irradiance.name, tonemaping=True, clamp=True 91 | ).to("cuda") 92 | elif ( 93 | irradiance.name.endswith(".png") 94 | or irradiance.name.endswith(".jpg") 95 | or irradiance.name.endswith(".jpeg") 96 | ): 97 | irradiance_image = load_ldr_image( 98 | irradiance.name, from_srgb=True, clamp=True 99 | ).to("cuda") 100 | 101 | generator = torch.Generator(device="cuda").manual_seed(seed) 102 | 103 | # Set default height and width 104 | height = 768 105 | width = 768 106 | 107 | # Check if any of the input images are not None 108 | # and set the height and width accordingly 109 | images = [ 110 | albedo_image, 111 | normal_image, 112 | roughness_image, 113 | metallic_image, 114 | irradiance_image, 115 | ] 116 | for img in images: 117 | if img is not None: 118 | height = img.shape[1] 119 | width = img.shape[2] 120 | break 121 | 122 | required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"] 123 | return_list = [] 124 | for i in range(num_samples): 125 | generated_image = pipe( 126 | prompt=prompt, 127 | albedo=albedo_image, 128 | normal=normal_image, 129 | roughness=roughness_image, 130 | metallic=metallic_image, 131 | irradiance=irradiance_image, 132 | num_inference_steps=inference_step, 133 | height=height, 134 | width=width, 135 | generator=generator, 136 | required_aovs=required_aovs, 137 | guidance_scale=guidance_scale, 138 | image_guidance_scale=image_guidance_scale, 139 | guidance_rescale=0.7, 140 | output_type="np", 141 | ).images[0] 142 | 143 | generated_image = (generated_image, f"Generated Image {i}") 144 | return_list.append(generated_image) 145 | 146 | if albedo_image is not None: 147 | albedo_image = albedo_image ** (1 / 2.2) 148 | albedo_image = albedo_image.cpu().numpy().transpose(1, 2, 0) 149 | else: 150 | albedo_image = np.zeros((height, width, 3)) 151 | 152 | if normal_image is not None: 153 | normal_image = normal_image * 0.5 + 0.5 154 | normal_image = normal_image.cpu().numpy().transpose(1, 2, 0) 155 | else: 156 | normal_image = np.zeros((height, width, 3)) 157 | 158 | if roughness_image is not None: 159 | roughness_image = roughness_image.cpu().numpy().transpose(1, 2, 0) 160 | else: 161 | roughness_image = np.zeros((height, width, 3)) 162 | 163 | if metallic_image is not None: 164 | metallic_image = metallic_image.cpu().numpy().transpose(1, 2, 0) 165 | else: 166 | 
metallic_image = np.zeros((height, width, 3)) 167 | 168 | if irradiance_image is not None: 169 | irradiance_image = irradiance_image ** (1 / 2.2) 170 | irradiance_image = irradiance_image.cpu().numpy().transpose(1, 2, 0) 171 | else: 172 | irradiance_image = np.zeros((height, width, 3)) 173 | 174 | albedo_image = (albedo_image, "Albedo") 175 | normal_image = (normal_image, "Normal") 176 | roughness_image = (roughness_image, "Roughness") 177 | metallic_image = (metallic_image, "Metallic") 178 | irradiance_image = (irradiance_image, "Irradiance") 179 | 180 | return_list.append(albedo_image) 181 | return_list.append(normal_image) 182 | return_list.append(roughness_image) 183 | return_list.append(metallic_image) 184 | return_list.append(irradiance_image) 185 | 186 | return return_list 187 | 188 | block = gr.Blocks() 189 | with block: 190 | with gr.Row(): 191 | gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)") 192 | with gr.Row(): 193 | # Input side 194 | with gr.Column(): 195 | gr.Markdown("### Given intrinsic channels") 196 | albedo = gr.File(label="Albedo", file_types=[".exr", ".png", ".jpg"]) 197 | normal = gr.File(label="Normal", file_types=[".exr", ".png", ".jpg"]) 198 | roughness = gr.File( 199 | label="Roughness", file_types=[".exr", ".png", ".jpg"] 200 | ) 201 | metallic = gr.File( 202 | label="Metallic", file_types=[".exr", ".png", ".jpg"] 203 | ) 204 | irradiance = gr.File( 205 | label="Irradiance", file_types=[".exr", ".png", ".jpg"] 206 | ) 207 | 208 | gr.Markdown("### Parameters") 209 | prompt = gr.Textbox(label="Prompt") 210 | run_button = gr.Button(label="Run") 211 | with gr.Accordion("Advanced options", open=False): 212 | seed = gr.Slider( 213 | label="Seed", 214 | minimum=-1, 215 | maximum=2147483647, 216 | step=1, 217 | randomize=True, 218 | ) 219 | inference_step = gr.Slider( 220 | label="Inference Step", 221 | minimum=1, 222 | maximum=100, 223 | step=1, 224 | value=50, 225 | ) 226 | num_samples = gr.Slider( 227 | label="Samples", 228 | minimum=1, 229 | maximum=100, 230 | step=1, 231 | value=1, 232 | ) 233 | guidance_scale = gr.Slider( 234 | label="Guidance Scale", 235 | minimum=0.0, 236 | maximum=10.0, 237 | step=0.1, 238 | value=7.5, 239 | ) 240 | image_guidance_scale = gr.Slider( 241 | label="Image Guidance Scale", 242 | minimum=0.0, 243 | maximum=10.0, 244 | step=0.1, 245 | value=1.5, 246 | ) 247 | 248 | # Output side 249 | with gr.Column(): 250 | gr.Markdown("### Output Gallery") 251 | result_gallery = gr.Gallery( 252 | label="Output", 253 | show_label=False, 254 | elem_id="gallery", 255 | columns=2, 256 | ) 257 | 258 | inputs = [ 259 | albedo, 260 | normal, 261 | roughness, 262 | metallic, 263 | irradiance, 264 | prompt, 265 | seed, 266 | inference_step, 267 | num_samples, 268 | guidance_scale, 269 | image_guidance_scale, 270 | ] 271 | run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True) 272 | 273 | return block 274 | 275 | 276 | if __name__ == "__main__": 277 | demo = get_x2rgb_demo() 278 | demo.queue(max_size=1) 279 | demo.launch() 280 | -------------------------------------------------------------------------------- /x2rgb/load_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | 6 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 7 | import numpy as np 8 | 9 | 10 | def convert_rgb_2_XYZ(rgb): 11 | # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html 12 | # rgb: 
(h, w, 3) 13 | # XYZ: (h, w, 3) 14 | XYZ = torch.ones_like(rgb) 15 | XYZ[:, :, 0] = ( 16 | 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2] 17 | ) 18 | XYZ[:, :, 1] = ( 19 | 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2] 20 | ) 21 | XYZ[:, :, 2] = ( 22 | 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2] 23 | ) 24 | return XYZ 25 | 26 | 27 | def convert_XYZ_2_Yxy(XYZ): 28 | # XYZ: (h, w, 3) 29 | # Yxy: (h, w, 3) 30 | Yxy = torch.ones_like(XYZ) 31 | Yxy[:, :, 0] = XYZ[:, :, 1] 32 | sum = torch.sum(XYZ, dim=2) 33 | inv_sum = 1.0 / torch.clamp(sum, min=1e-4) 34 | Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum 35 | Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum 36 | return Yxy 37 | 38 | 39 | def convert_rgb_2_Yxy(rgb): 40 | # rgb: (h, w, 3) 41 | # Yxy: (h, w, 3) 42 | return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb)) 43 | 44 | 45 | def convert_XYZ_2_rgb(XYZ): 46 | # XYZ: (h, w, 3) 47 | # rgb: (h, w, 3) 48 | rgb = torch.ones_like(XYZ) 49 | rgb[:, :, 0] = ( 50 | 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2] 51 | ) 52 | rgb[:, :, 1] = ( 53 | -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2] 54 | ) 55 | rgb[:, :, 2] = ( 56 | 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2] 57 | ) 58 | return rgb 59 | 60 | 61 | def convert_Yxy_2_XYZ(Yxy): 62 | # Yxy: (h, w, 3) 63 | # XYZ: (h, w, 3) 64 | XYZ = torch.ones_like(Yxy) 65 | XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0] 66 | XYZ[:, :, 1] = Yxy[:, :, 0] 67 | XYZ[:, :, 2] = ( 68 | (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2]) 69 | / torch.clamp(Yxy[:, :, 2], min=1e-4) 70 | * Yxy[:, :, 0] 71 | ) 72 | return XYZ 73 | 74 | 75 | def convert_Yxy_2_rgb(Yxy): 76 | # Yxy: (h, w, 3) 77 | # rgb: (h, w, 3) 78 | return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy)) 79 | 80 | 81 | def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False): 82 | # Load png or jpg image 83 | image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 84 | image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c) 85 | image[~torch.isfinite(image)] = 0 86 | if from_srgb: 87 | # Convert from sRGB to linear RGB 88 | image = image**2.2 89 | if clamp: 90 | image = torch.clamp(image, min=0.0, max=1.0) 91 | if normalize: 92 | # Normalize to [-1, 1] 93 | image = image * 2.0 - 1.0 94 | image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6) 95 | return image.permute(2, 0, 1) # returns (c, h, w) 96 | 97 | 98 | def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False): 99 | image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB) 100 | image = torch.from_numpy(image.astype("float32")) # (h, w, c) 101 | image[~torch.isfinite(image)] = 0 102 | if tonemaping: 103 | # Exposure adjuestment 104 | image_Yxy = convert_rgb_2_Yxy(image) 105 | lum = ( 106 | image[:, :, 0:1] * 0.2125 107 | + image[:, :, 1:2] * 0.7154 108 | + image[:, :, 2:3] * 0.0721 109 | ) 110 | lum = torch.log(torch.clamp(lum, min=1e-6)) 111 | lum_mean = torch.exp(torch.mean(lum)) 112 | lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6) 113 | image_Yxy[:, :, 0:1] = lp 114 | image = convert_Yxy_2_rgb(image_Yxy) 115 | if clamp: 116 | image = torch.clamp(image, min=0.0, max=1.0) 117 | if normalize: 118 | image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6) 119 | return image.permute(2, 0, 1) # returns (c, h, w) 120 | 
-------------------------------------------------------------------------------- /x2rgb/pipeline_x2rgb.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import PIL 7 | import torch 8 | import torch.nn.functional as F 9 | from diffusers.configuration_utils import register_to_config 10 | from diffusers.image_processor import VaeImageProcessor 11 | from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin 12 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 13 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 14 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 15 | rescale_noise_cfg, 16 | ) 17 | from diffusers.schedulers import KarrasDiffusionSchedulers 18 | from diffusers.utils import CONFIG_NAME, BaseOutput, deprecate, logging, randn_tensor 19 | from packaging import version 20 | from transformers import CLIPTextModel, CLIPTokenizer 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class VaeImageProcrssorAOV(VaeImageProcessor): 26 | """ 27 | Image processor for VAE AOV. 28 | 29 | Args: 30 | do_resize (`bool`, *optional*, defaults to `True`): 31 | Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. 32 | vae_scale_factor (`int`, *optional*, defaults to `8`): 33 | VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. 34 | resample (`str`, *optional*, defaults to `lanczos`): 35 | Resampling filter to use when resizing the image. 36 | do_normalize (`bool`, *optional*, defaults to `True`): 37 | Whether to normalize the image to [-1,1]. 38 | """ 39 | 40 | config_name = CONFIG_NAME 41 | 42 | @register_to_config 43 | def __init__( 44 | self, 45 | do_resize: bool = True, 46 | vae_scale_factor: int = 8, 47 | resample: str = "lanczos", 48 | do_normalize: bool = True, 49 | ): 50 | super().__init__() 51 | 52 | def postprocess( 53 | self, 54 | image: torch.FloatTensor, 55 | output_type: str = "pil", 56 | do_denormalize: Optional[List[bool]] = None, 57 | do_gamma_correction: bool = True, 58 | ): 59 | if not isinstance(image, torch.Tensor): 60 | raise ValueError( 61 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" 62 | ) 63 | if output_type not in ["latent", "pt", "np", "pil"]: 64 | deprecation_message = ( 65 | f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " 66 | "`pil`, `np`, `pt`, `latent`" 67 | ) 68 | deprecate( 69 | "Unsupported output_type", 70 | "1.0.0", 71 | deprecation_message, 72 | standard_warn=False, 73 | ) 74 | output_type = "np" 75 | 76 | if output_type == "latent": 77 | return image 78 | 79 | if do_denormalize is None: 80 | do_denormalize = [self.config.do_normalize] * image.shape[0] 81 | 82 | image = torch.stack( 83 | [ 84 | self.denormalize(image[i]) if do_denormalize[i] else image[i] 85 | for i in range(image.shape[0]) 86 | ] 87 | ) 88 | 89 | # Gamma correction 90 | if do_gamma_correction: 91 | image = torch.pow(image, 1.0 / 2.2) 92 | 93 | if output_type == "pt": 94 | return image 95 | 96 | image = self.pt_to_numpy(image) 97 | 98 | if output_type == "np": 99 | return image 100 | 101 | if output_type == "pil": 102 | return self.numpy_to_pil(image) 103 | 104 | def preprocess_normal( 105 | self, 106 | image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], 107 | height: Optional[int] = None, 108 | width: Optional[int] = None, 109 | ) -> torch.Tensor: 110 | image = torch.stack([image], axis=0) 111 | return image 112 | 113 | 114 | @dataclass 115 | class StableDiffusionAOVPipelineOutput(BaseOutput): 116 | """ 117 | Output class for Stable Diffusion AOV pipelines. 118 | 119 | Args: 120 | images (`List[PIL.Image.Image]` or `np.ndarray`) 121 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 122 | num_channels)`. 123 | nsfw_content_detected (`List[bool]`) 124 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or 125 | `None` if safety checking could not be performed. 126 | """ 127 | 128 | images: Union[List[PIL.Image.Image], np.ndarray] 129 | predicted_x0_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] = None 130 | 131 | 132 | class StableDiffusionAOVDropoutPipeline( 133 | DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin 134 | ): 135 | r""" 136 | Pipeline for AOVs. 137 | 138 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 139 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 140 | 141 | The pipeline also inherits the following loading methods: 142 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 143 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 144 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 145 | 146 | Args: 147 | vae ([`AutoencoderKL`]): 148 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 149 | text_encoder ([`~transformers.CLIPTextModel`]): 150 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 151 | tokenizer ([`~transformers.CLIPTokenizer`]): 152 | A `CLIPTokenizer` to tokenize text. 153 | unet ([`UNet2DConditionModel`]): 154 | A `UNet2DConditionModel` to denoise the encoded image latents. 155 | scheduler ([`SchedulerMixin`]): 156 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 157 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
158 | """ 159 | 160 | def __init__( 161 | self, 162 | vae: AutoencoderKL, 163 | text_encoder: CLIPTextModel, 164 | tokenizer: CLIPTokenizer, 165 | unet: UNet2DConditionModel, 166 | scheduler: KarrasDiffusionSchedulers, 167 | ): 168 | super().__init__() 169 | 170 | self.register_modules( 171 | vae=vae, 172 | text_encoder=text_encoder, 173 | tokenizer=tokenizer, 174 | unet=unet, 175 | scheduler=scheduler, 176 | ) 177 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 178 | self.image_processor = VaeImageProcrssorAOV( 179 | vae_scale_factor=self.vae_scale_factor 180 | ) 181 | self.register_to_config() 182 | 183 | def _encode_prompt( 184 | self, 185 | prompt, 186 | device, 187 | num_images_per_prompt, 188 | do_classifier_free_guidance, 189 | negative_prompt=None, 190 | prompt_embeds: Optional[torch.FloatTensor] = None, 191 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 192 | ): 193 | r""" 194 | Encodes the prompt into text encoder hidden states. 195 | 196 | Args: 197 | prompt (`str` or `List[str]`, *optional*): 198 | prompt to be encoded 199 | device: (`torch.device`): 200 | torch device 201 | num_images_per_prompt (`int`): 202 | number of images that should be generated per prompt 203 | do_classifier_free_guidance (`bool`): 204 | whether to use classifier free guidance or not 205 | negative_ prompt (`str` or `List[str]`, *optional*): 206 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 207 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 208 | less than `1`). 209 | prompt_embeds (`torch.FloatTensor`, *optional*): 210 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 211 | provided, text embeddings will be generated from `prompt` input argument. 212 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 213 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 214 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 215 | argument. 
216 | """ 217 | if prompt is not None and isinstance(prompt, str): 218 | batch_size = 1 219 | elif prompt is not None and isinstance(prompt, list): 220 | batch_size = len(prompt) 221 | else: 222 | batch_size = prompt_embeds.shape[0] 223 | 224 | if prompt_embeds is None: 225 | # textual inversion: procecss multi-vector tokens if necessary 226 | if isinstance(self, TextualInversionLoaderMixin): 227 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 228 | 229 | text_inputs = self.tokenizer( 230 | prompt, 231 | padding="max_length", 232 | max_length=self.tokenizer.model_max_length, 233 | truncation=True, 234 | return_tensors="pt", 235 | ) 236 | text_input_ids = text_inputs.input_ids 237 | untruncated_ids = self.tokenizer( 238 | prompt, padding="longest", return_tensors="pt" 239 | ).input_ids 240 | 241 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 242 | -1 243 | ] and not torch.equal(text_input_ids, untruncated_ids): 244 | removed_text = self.tokenizer.batch_decode( 245 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 246 | ) 247 | logger.warning( 248 | "The following part of your input was truncated because CLIP can only handle sequences up to" 249 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 250 | ) 251 | 252 | if ( 253 | hasattr(self.text_encoder.config, "use_attention_mask") 254 | and self.text_encoder.config.use_attention_mask 255 | ): 256 | attention_mask = text_inputs.attention_mask.to(device) 257 | else: 258 | attention_mask = None 259 | 260 | prompt_embeds = self.text_encoder( 261 | text_input_ids.to(device), 262 | attention_mask=attention_mask, 263 | ) 264 | prompt_embeds = prompt_embeds[0] 265 | 266 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 267 | 268 | bs_embed, seq_len, _ = prompt_embeds.shape 269 | # duplicate text embeddings for each generation per prompt, using mps friendly method 270 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 271 | prompt_embeds = prompt_embeds.view( 272 | bs_embed * num_images_per_prompt, seq_len, -1 273 | ) 274 | 275 | # get unconditional embeddings for classifier free guidance 276 | if do_classifier_free_guidance and negative_prompt_embeds is None: 277 | uncond_tokens: List[str] 278 | if negative_prompt is None: 279 | uncond_tokens = [""] * batch_size 280 | elif type(prompt) is not type(negative_prompt): 281 | raise TypeError( 282 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 283 | f" {type(prompt)}." 284 | ) 285 | elif isinstance(negative_prompt, str): 286 | uncond_tokens = [negative_prompt] 287 | elif batch_size != len(negative_prompt): 288 | raise ValueError( 289 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 290 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 291 | " the batch size of `prompt`." 
292 | ) 293 | else: 294 | uncond_tokens = negative_prompt 295 | 296 | # textual inversion: procecss multi-vector tokens if necessary 297 | if isinstance(self, TextualInversionLoaderMixin): 298 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 299 | 300 | max_length = prompt_embeds.shape[1] 301 | uncond_input = self.tokenizer( 302 | uncond_tokens, 303 | padding="max_length", 304 | max_length=max_length, 305 | truncation=True, 306 | return_tensors="pt", 307 | ) 308 | 309 | if ( 310 | hasattr(self.text_encoder.config, "use_attention_mask") 311 | and self.text_encoder.config.use_attention_mask 312 | ): 313 | attention_mask = uncond_input.attention_mask.to(device) 314 | else: 315 | attention_mask = None 316 | 317 | negative_prompt_embeds = self.text_encoder( 318 | uncond_input.input_ids.to(device), 319 | attention_mask=attention_mask, 320 | ) 321 | negative_prompt_embeds = negative_prompt_embeds[0] 322 | 323 | if do_classifier_free_guidance: 324 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 325 | seq_len = negative_prompt_embeds.shape[1] 326 | 327 | negative_prompt_embeds = negative_prompt_embeds.to( 328 | dtype=self.text_encoder.dtype, device=device 329 | ) 330 | 331 | negative_prompt_embeds = negative_prompt_embeds.repeat( 332 | 1, num_images_per_prompt, 1 333 | ) 334 | negative_prompt_embeds = negative_prompt_embeds.view( 335 | batch_size * num_images_per_prompt, seq_len, -1 336 | ) 337 | 338 | # For classifier free guidance, we need to do two forward passes. 339 | # Here we concatenate the unconditional and text embeddings into a single batch 340 | # to avoid doing two forward passes 341 | # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 342 | prompt_embeds = torch.cat( 343 | [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 344 | ) 345 | 346 | return prompt_embeds 347 | 348 | def prepare_extra_step_kwargs(self, generator, eta): 349 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 350 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 351 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 352 | # and should be between [0, 1] 353 | 354 | accepts_eta = "eta" in set( 355 | inspect.signature(self.scheduler.step).parameters.keys() 356 | ) 357 | extra_step_kwargs = {} 358 | if accepts_eta: 359 | extra_step_kwargs["eta"] = eta 360 | 361 | # check if the scheduler accepts generator 362 | accepts_generator = "generator" in set( 363 | inspect.signature(self.scheduler.step).parameters.keys() 364 | ) 365 | if accepts_generator: 366 | extra_step_kwargs["generator"] = generator 367 | return extra_step_kwargs 368 | 369 | def check_inputs( 370 | self, 371 | prompt, 372 | callback_steps, 373 | negative_prompt=None, 374 | prompt_embeds=None, 375 | negative_prompt_embeds=None, 376 | ): 377 | if (callback_steps is None) or ( 378 | callback_steps is not None 379 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 380 | ): 381 | raise ValueError( 382 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 383 | f" {type(callback_steps)}." 384 | ) 385 | 386 | if prompt is not None and prompt_embeds is not None: 387 | raise ValueError( 388 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" 389 | " only forward one of the two." 390 | ) 391 | elif prompt is None and prompt_embeds is None: 392 | raise ValueError( 393 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 394 | ) 395 | elif prompt is not None and ( 396 | not isinstance(prompt, str) and not isinstance(prompt, list) 397 | ): 398 | raise ValueError( 399 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 400 | ) 401 | 402 | if negative_prompt is not None and negative_prompt_embeds is not None: 403 | raise ValueError( 404 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 405 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 406 | ) 407 | 408 | if prompt_embeds is not None and negative_prompt_embeds is not None: 409 | if prompt_embeds.shape != negative_prompt_embeds.shape: 410 | raise ValueError( 411 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 412 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 413 | f" {negative_prompt_embeds.shape}." 414 | ) 415 | 416 | def prepare_latents( 417 | self, 418 | batch_size, 419 | num_channels_latents, 420 | height, 421 | width, 422 | dtype, 423 | device, 424 | generator, 425 | latents=None, 426 | ): 427 | shape = ( 428 | batch_size, 429 | num_channels_latents, 430 | height // self.vae_scale_factor, 431 | width // self.vae_scale_factor, 432 | ) 433 | if isinstance(generator, list) and len(generator) != batch_size: 434 | raise ValueError( 435 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 436 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 437 | ) 438 | 439 | if latents is None: 440 | latents = randn_tensor( 441 | shape, generator=generator, device=device, dtype=dtype 442 | ) 443 | else: 444 | latents = latents.to(device) 445 | 446 | # scale the initial noise by the standard deviation required by the scheduler 447 | latents = latents * self.scheduler.init_noise_sigma 448 | return latents 449 | 450 | def prepare_image_latents( 451 | self, 452 | image, 453 | batch_size, 454 | num_images_per_prompt, 455 | dtype, 456 | device, 457 | do_classifier_free_guidance, 458 | generator=None, 459 | ): 460 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 461 | raise ValueError( 462 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 463 | ) 464 | 465 | image = image.to(device=device, dtype=dtype) 466 | 467 | batch_size = batch_size * num_images_per_prompt 468 | 469 | if image.shape[1] == 4: 470 | image_latents = image 471 | else: 472 | if isinstance(generator, list) and len(generator) != batch_size: 473 | raise ValueError( 474 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 475 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
476 | ) 477 | 478 | if isinstance(generator, list): 479 | image_latents = [ 480 | self.vae.encode(image[i : i + 1]).latent_dist.mode() 481 | for i in range(batch_size) 482 | ] 483 | image_latents = torch.cat(image_latents, dim=0) 484 | else: 485 | image_latents = self.vae.encode(image).latent_dist.mode() 486 | 487 | if ( 488 | batch_size > image_latents.shape[0] 489 | and batch_size % image_latents.shape[0] == 0 490 | ): 491 | # expand image_latents for batch_size 492 | deprecation_message = ( 493 | f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" 494 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 495 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 496 | " your script to pass as many initial images as text prompts to suppress this warning." 497 | ) 498 | deprecate( 499 | "len(prompt) != len(image)", 500 | "1.0.0", 501 | deprecation_message, 502 | standard_warn=False, 503 | ) 504 | additional_image_per_prompt = batch_size // image_latents.shape[0] 505 | image_latents = torch.cat( 506 | [image_latents] * additional_image_per_prompt, dim=0 507 | ) 508 | elif ( 509 | batch_size > image_latents.shape[0] 510 | and batch_size % image_latents.shape[0] != 0 511 | ): 512 | raise ValueError( 513 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 514 | ) 515 | else: 516 | image_latents = torch.cat([image_latents], dim=0) 517 | 518 | if do_classifier_free_guidance: 519 | uncond_image_latents = torch.zeros_like(image_latents) 520 | image_latents = torch.cat( 521 | [image_latents, image_latents, uncond_image_latents], dim=0 522 | ) 523 | 524 | return image_latents 525 | 526 | @torch.no_grad() 527 | def __call__( 528 | self, 529 | height: int, 530 | width: int, 531 | prompt: Union[str, List[str]] = None, 532 | albedo: Optional[ 533 | Union[ 534 | torch.FloatTensor, 535 | PIL.Image.Image, 536 | np.ndarray, 537 | List[torch.FloatTensor], 538 | List[PIL.Image.Image], 539 | List[np.ndarray], 540 | ] 541 | ] = None, 542 | normal: Optional[ 543 | Union[ 544 | torch.FloatTensor, 545 | PIL.Image.Image, 546 | np.ndarray, 547 | List[torch.FloatTensor], 548 | List[PIL.Image.Image], 549 | List[np.ndarray], 550 | ] 551 | ] = None, 552 | roughness: Optional[ 553 | Union[ 554 | torch.FloatTensor, 555 | PIL.Image.Image, 556 | np.ndarray, 557 | List[torch.FloatTensor], 558 | List[PIL.Image.Image], 559 | List[np.ndarray], 560 | ] 561 | ] = None, 562 | metallic: Optional[ 563 | Union[ 564 | torch.FloatTensor, 565 | PIL.Image.Image, 566 | np.ndarray, 567 | List[torch.FloatTensor], 568 | List[PIL.Image.Image], 569 | List[np.ndarray], 570 | ] 571 | ] = None, 572 | irradiance: Optional[ 573 | Union[ 574 | torch.FloatTensor, 575 | PIL.Image.Image, 576 | np.ndarray, 577 | List[torch.FloatTensor], 578 | List[PIL.Image.Image], 579 | List[np.ndarray], 580 | ] 581 | ] = None, 582 | guidance_scale: float = 0.0, 583 | image_guidance_scale: float = 0.0, 584 | guidance_rescale: float = 0.0, 585 | num_inference_steps: int = 100, 586 | required_aovs: List[str] = ["albedo"], 587 | return_predicted_x0s: bool = False, 588 | negative_prompt: Optional[Union[str, List[str]]] = None, 589 | num_images_per_prompt: Optional[int] = 1, 590 | eta: float = 0.0, 591 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 592 | latents: Optional[torch.FloatTensor] = None, 593 | prompt_embeds: Optional[torch.FloatTensor] = None, 
594 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 595 | output_type: Optional[str] = "pil", 596 | return_dict: bool = True, 597 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 598 | callback_steps: int = 1, 599 | ): 600 | r""" 601 | The call function to the pipeline for generation. 602 | 603 | Args: 604 | prompt (`str` or `List[str]`, *optional*): 605 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 606 | image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 607 | `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept 608 | image latents as `image`, but if passing latents directly it is not encoded again. 609 | num_inference_steps (`int`, *optional*, defaults to 100): 610 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 611 | expense of slower inference. 612 | guidance_scale (`float`, *optional*, defaults to 7.5): 613 | A higher guidance scale value encourages the model to generate images closely linked to the text 614 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 615 | image_guidance_scale (`float`, *optional*, defaults to 1.5): 616 | Push the generated image towards the inital `image`. Image guidance scale is enabled by setting 617 | `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely 618 | linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a 619 | value of at least `1`. 620 | negative_prompt (`str` or `List[str]`, *optional*): 621 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 622 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 623 | num_images_per_prompt (`int`, *optional*, defaults to 1): 624 | The number of images to generate per prompt. 625 | eta (`float`, *optional*, defaults to 0.0): 626 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 627 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 628 | generator (`torch.Generator`, *optional*): 629 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 630 | generation deterministic. 631 | latents (`torch.FloatTensor`, *optional*): 632 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 633 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 634 | tensor is generated by sampling using the supplied random `generator`. 635 | prompt_embeds (`torch.FloatTensor`, *optional*): 636 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 637 | provided, text embeddings are generated from the `prompt` input argument. 638 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 639 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 640 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 641 | output_type (`str`, *optional*, defaults to `"pil"`): 642 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
643 | return_dict (`bool`, *optional*, defaults to `True`): 644 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 645 | plain tuple. 646 | callback (`Callable`, *optional*): 647 | A function that calls every `callback_steps` steps during inference. The function is called with the 648 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 649 | callback_steps (`int`, *optional*, defaults to 1): 650 | The frequency at which the `callback` function is called. If not specified, the callback is called at 651 | every step. 652 | 653 | Examples: 654 | 655 | ```py 656 | >>> import PIL 657 | >>> import requests 658 | >>> import torch 659 | >>> from io import BytesIO 660 | 661 | >>> from diffusers import StableDiffusionInstructPix2PixPipeline 662 | 663 | 664 | >>> def download_image(url): 665 | ... response = requests.get(url) 666 | ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") 667 | 668 | 669 | >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" 670 | 671 | >>> image = download_image(img_url).resize((512, 512)) 672 | 673 | >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( 674 | ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 675 | ... ) 676 | >>> pipe = pipe.to("cuda") 677 | 678 | >>> prompt = "make the mountains snowy" 679 | >>> image = pipe(prompt=prompt, image=image).images[0] 680 | ``` 681 | 682 | Returns: 683 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 684 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 685 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 686 | second element is a list of `bool`s indicating whether the corresponding generated image contains 687 | "not-safe-for-work" (nsfw) content. 688 | """ 689 | # 0. Check inputs 690 | self.check_inputs( 691 | prompt, 692 | callback_steps, 693 | negative_prompt, 694 | prompt_embeds, 695 | negative_prompt_embeds, 696 | ) 697 | 698 | # 1. Define call parameters 699 | if prompt is not None and isinstance(prompt, str): 700 | batch_size = 1 701 | elif prompt is not None and isinstance(prompt, list): 702 | batch_size = len(prompt) 703 | else: 704 | batch_size = prompt_embeds.shape[0] 705 | 706 | device = self._execution_device 707 | do_classifier_free_guidance = ( 708 | guidance_scale >= 1.0 and image_guidance_scale >= 1.0 709 | ) 710 | # check if scheduler is in sigmas space 711 | scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") 712 | 713 | # 2. Encode input prompt 714 | prompt_embeds = self._encode_prompt( 715 | prompt, 716 | device, 717 | num_images_per_prompt, 718 | do_classifier_free_guidance, 719 | negative_prompt, 720 | prompt_embeds=prompt_embeds, 721 | negative_prompt_embeds=negative_prompt_embeds, 722 | ) 723 | 724 | # 3. 
Preprocess image 725 | # For normal, the preprocessing does nothing 726 | # For others, the preprocessing remap the values to [-1, 1] 727 | preprocessed_aovs = {} 728 | for aov_name in required_aovs: 729 | if aov_name == "albedo": 730 | if albedo is not None: 731 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 732 | albedo 733 | ) 734 | else: 735 | preprocessed_aovs[aov_name] = None 736 | 737 | if aov_name == "normal": 738 | if normal is not None: 739 | preprocessed_aovs[aov_name] = ( 740 | self.image_processor.preprocess_normal(normal) 741 | ) 742 | else: 743 | preprocessed_aovs[aov_name] = None 744 | 745 | if aov_name == "roughness": 746 | if roughness is not None: 747 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 748 | roughness 749 | ) 750 | else: 751 | preprocessed_aovs[aov_name] = None 752 | if aov_name == "metallic": 753 | if metallic is not None: 754 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 755 | metallic 756 | ) 757 | else: 758 | preprocessed_aovs[aov_name] = None 759 | if aov_name == "irradiance": 760 | if irradiance is not None: 761 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 762 | irradiance 763 | ) 764 | else: 765 | preprocessed_aovs[aov_name] = None 766 | 767 | # 4. set timesteps 768 | self.scheduler.set_timesteps(num_inference_steps, device=device) 769 | timesteps = self.scheduler.timesteps 770 | 771 | # 5. Prepare latent variables 772 | num_channels_latents = self.vae.config.latent_channels 773 | latents = self.prepare_latents( 774 | batch_size * num_images_per_prompt, 775 | num_channels_latents, 776 | height, 777 | width, 778 | prompt_embeds.dtype, 779 | device, 780 | generator, 781 | latents, 782 | ) 783 | 784 | height_latent, width_latent = latents.shape[-2:] 785 | 786 | # 6. Prepare Image latents 787 | image_latents = [] 788 | # Magicial scaling factors for each AOV (calculated from the training data) 789 | scaling_factors = { 790 | "albedo": 0.17301377137652138, 791 | "normal": 0.17483895473058078, 792 | "roughness": 0.1680724853626448, 793 | "metallic": 0.13135013390855135, 794 | } 795 | for aov_name, aov in preprocessed_aovs.items(): 796 | if aov is None: 797 | image_latent = torch.zeros( 798 | batch_size, 799 | num_channels_latents, 800 | height_latent, 801 | width_latent, 802 | dtype=prompt_embeds.dtype, 803 | device=device, 804 | ) 805 | if aov_name == "irradiance": 806 | image_latent = image_latent[:, 0:3] 807 | if do_classifier_free_guidance: 808 | image_latents.append( 809 | torch.cat([image_latent, image_latent, image_latent], dim=0) 810 | ) 811 | else: 812 | image_latents.append(image_latent) 813 | else: 814 | if aov_name == "irradiance": 815 | image_latent = F.interpolate( 816 | aov.to(device=device, dtype=prompt_embeds.dtype), 817 | size=(height_latent, width_latent), 818 | mode="bilinear", 819 | align_corners=False, 820 | antialias=True, 821 | ) 822 | if do_classifier_free_guidance: 823 | uncond_image_latent = torch.zeros_like(image_latent) 824 | image_latent = torch.cat( 825 | [image_latent, image_latent, uncond_image_latent], dim=0 826 | ) 827 | else: 828 | scaling_factor = scaling_factors[aov_name] 829 | image_latent = ( 830 | self.prepare_image_latents( 831 | aov, 832 | batch_size, 833 | num_images_per_prompt, 834 | prompt_embeds.dtype, 835 | device, 836 | do_classifier_free_guidance, 837 | generator, 838 | ) 839 | * scaling_factor 840 | ) 841 | image_latents.append(image_latent) 842 | image_latents = torch.cat(image_latents, dim=1) 843 | 844 | # 7. 
Check that shapes of latents and image match the UNet channels 845 | num_channels_image = image_latents.shape[1] 846 | if num_channels_latents + num_channels_image != self.unet.config.in_channels: 847 | raise ValueError( 848 | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" 849 | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" 850 | f" `num_channels_image`: {num_channels_image} " 851 | f" = {num_channels_latents+num_channels_image}. Please verify the config of" 852 | " `pipeline.unet` or your `image` input." 853 | ) 854 | 855 | # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 856 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 857 | 858 | predicted_x0s = [] 859 | 860 | # 9. Denoising loop 861 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 862 | with self.progress_bar(total=num_inference_steps) as progress_bar: 863 | for i, t in enumerate(timesteps): 864 | # Expand the latents if we are doing classifier free guidance. 865 | # The latents are expanded 3 times because for pix2pix the guidance\ 866 | # is applied for both the text and the input image. 867 | latent_model_input = ( 868 | torch.cat([latents] * 3) if do_classifier_free_guidance else latents 869 | ) 870 | 871 | # concat latents, image_latents in the channel dimension 872 | scaled_latent_model_input = self.scheduler.scale_model_input( 873 | latent_model_input, t 874 | ) 875 | scaled_latent_model_input = torch.cat( 876 | [scaled_latent_model_input, image_latents], dim=1 877 | ) 878 | 879 | # predict the noise residual 880 | noise_pred = self.unet( 881 | scaled_latent_model_input, 882 | t, 883 | encoder_hidden_states=prompt_embeds, 884 | return_dict=False, 885 | )[0] 886 | 887 | # perform guidance 888 | if do_classifier_free_guidance: 889 | ( 890 | noise_pred_text, 891 | noise_pred_image, 892 | noise_pred_uncond, 893 | ) = noise_pred.chunk(3) 894 | noise_pred = ( 895 | noise_pred_uncond 896 | + guidance_scale * (noise_pred_text - noise_pred_image) 897 | + image_guidance_scale * (noise_pred_image - noise_pred_uncond) 898 | ) 899 | 900 | if do_classifier_free_guidance and guidance_rescale > 0.0: 901 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 902 | noise_pred = rescale_noise_cfg( 903 | noise_pred, noise_pred_text, guidance_rescale=guidance_rescale 904 | ) 905 | 906 | # compute the previous noisy sample x_t -> x_t-1 907 | output = self.scheduler.step( 908 | noise_pred, t, latents, **extra_step_kwargs, return_dict=True 909 | ) 910 | 911 | latents = output[0] 912 | 913 | if return_predicted_x0s: 914 | predicted_x0s.append(output[1]) 915 | 916 | # call the callback, if provided 917 | if i == len(timesteps) - 1 or ( 918 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 919 | ): 920 | progress_bar.update() 921 | if callback is not None and i % callback_steps == 0: 922 | callback(i, t, latents) 923 | 924 | if not output_type == "latent": 925 | image = self.vae.decode( 926 | latents / self.vae.config.scaling_factor, return_dict=False 927 | )[0] 928 | 929 | if return_predicted_x0s: 930 | predicted_x0_images = [ 931 | self.vae.decode( 932 | predicted_x0 / self.vae.config.scaling_factor, return_dict=False 933 | )[0] 934 | for predicted_x0 in predicted_x0s 935 | ] 936 | else: 937 | image = latents 938 | predicted_x0_images = predicted_x0s 939 | 940 | do_denormalize = [True] * image.shape[0] 941 | 942 | image = self.image_processor.postprocess( 943 | image, output_type=output_type, do_denormalize=do_denormalize 944 | ) 945 | 946 | if return_predicted_x0s: 947 | predicted_x0_images = [ 948 | self.image_processor.postprocess( 949 | predicted_x0_image, 950 | output_type=output_type, 951 | do_denormalize=do_denormalize, 952 | ) 953 | for predicted_x0_image in predicted_x0_images 954 | ] 955 | 956 | # Offload last model to CPU 957 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 958 | self.final_offload_hook.offload() 959 | 960 | if not return_dict: 961 | return image 962 | 963 | if return_predicted_x0s: 964 | return StableDiffusionAOVPipelineOutput( 965 | images=image, predicted_x0_images=predicted_x0_images 966 | ) 967 | else: 968 | return StableDiffusionAOVPipelineOutput(images=image) 969 | -------------------------------------------------------------------------------- /x2rgb_inpainting/gradio_demo_x2rgb_inpainting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 5 | 6 | import gradio as gr 7 | import numpy as np 8 | import torch 9 | from diffusers import DDIMScheduler 10 | from load_image import load_exr_image, load_ldr_image 11 | from pipeline_x2rgb_inpainting import StableDiffusionAOVDropoutPipeline 12 | 13 | current_directory = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | 16 | def get_x2rgb_demo(): 17 | # Load pipeline 18 | pipe = StableDiffusionAOVDropoutPipeline.from_pretrained( 19 | "zheng95z/x-to-rgb-inpainting", 20 | torch_dtype=torch.float16, 21 | cache_dir=os.path.join(current_directory, "model_cache"), 22 | ).to("cuda") 23 | pipe.scheduler = DDIMScheduler.from_config( 24 | pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing" 25 | ) 26 | pipe.set_progress_bar_config(disable=True) 27 | pipe.to("cuda") 28 | 29 | # Augmentation 30 | def callback( 31 | albedo, 32 | normal, 33 | roughness, 34 | metallic, 35 | irradiance, 36 | mask, 37 | photo, 38 | prompt, 39 | seed, 40 | inference_step, 41 | num_samples, 42 | guidance_scale, 43 | image_guidance_scale, 44 | ): 45 | if albedo is None: 46 | albedo_image = None 47 | elif albedo.name.endswith(".exr"): 48 | albedo_image = 
load_exr_image(albedo.name, clamp=True).to("cuda") 49 | elif ( 50 | albedo.name.endswith(".png") 51 | or albedo.name.endswith(".jpg") 52 | or albedo.name.endswith(".jpeg") 53 | ): 54 | albedo_image = load_ldr_image(albedo.name, from_srgb=True).to("cuda") 55 | 56 | if normal is None: 57 | normal_image = None 58 | elif normal.name.endswith(".exr"): 59 | normal_image = load_exr_image(normal.name, normalize=True).to("cuda") 60 | elif ( 61 | normal.name.endswith(".png") 62 | or normal.name.endswith(".jpg") 63 | or normal.name.endswith(".jpeg") 64 | ): 65 | normal_image = load_ldr_image(normal.name, normalize=True).to("cuda") 66 | 67 | if roughness is None: 68 | roughness_image = None 69 | elif roughness.name.endswith(".exr"): 70 | roughness_image = load_exr_image(roughness.name, clamp=True).to("cuda") 71 | elif ( 72 | roughness.name.endswith(".png") 73 | or roughness.name.endswith(".jpg") 74 | or roughness.name.endswith(".jpeg") 75 | ): 76 | roughness_image = load_ldr_image(roughness.name, clamp=True).to("cuda") 77 | 78 | if metallic is None: 79 | metallic_image = None 80 | elif metallic.name.endswith(".exr"): 81 | metallic_image = load_exr_image(metallic.name, clamp=True).to("cuda") 82 | elif ( 83 | metallic.name.endswith(".png") 84 | or metallic.name.endswith(".jpg") 85 | or metallic.name.endswith(".jpeg") 86 | ): 87 | metallic_image = load_ldr_image(metallic.name, clamp=True).to("cuda") 88 | 89 | if irradiance is None: 90 | irradiance_image = None 91 | elif irradiance.name.endswith(".exr"): 92 | irradiance_image = load_exr_image( 93 | irradiance.name, tonemaping=True, clamp=True 94 | ).to("cuda") 95 | elif ( 96 | irradiance.name.endswith(".png") 97 | or irradiance.name.endswith(".jpg") 98 | or irradiance.name.endswith(".jpeg") 99 | ): 100 | irradiance_image = load_ldr_image( 101 | irradiance.name, from_srgb=True, clamp=True 102 | ).to("cuda") 103 | 104 | generator = torch.Generator(device="cuda").manual_seed(seed) 105 | 106 | height = 768 107 | width = 768 108 | # Check if any of the given images are not None 109 | images = [ 110 | albedo_image, 111 | normal_image, 112 | roughness_image, 113 | metallic_image, 114 | irradiance_image, 115 | ] 116 | 117 | assert photo is not None 118 | assert mask is not None 119 | if mask.name.endswith(".exr"): 120 | mask = load_exr_image(mask.name, clamp=True).to("cuda")[0:1] 121 | elif ( 122 | mask.name.endswith(".png") 123 | or mask.name.endswith(".jpg") 124 | or mask.name.endswith(".jpeg") 125 | ): 126 | mask = load_ldr_image(mask.name).to("cuda")[0:1] 127 | 128 | mask = 1.0 - mask 129 | 130 | if photo.name.endswith(".exr"): 131 | photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda") 132 | elif ( 133 | photo.name.endswith(".png") 134 | or photo.name.endswith(".jpg") 135 | or photo.name.endswith(".jpeg") 136 | ): 137 | photo = load_ldr_image(photo.name, from_srgb=True).to("cuda") 138 | 139 | for img in images: 140 | if img is not None: 141 | height = img.shape[1] 142 | width = img.shape[2] 143 | break 144 | 145 | masked_photo = photo * mask 146 | 147 | required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"] 148 | return_list = [] 149 | for i in range(num_samples): 150 | res = pipe( 151 | prompt=prompt, 152 | albedo=albedo_image, 153 | normal=normal_image, 154 | roughness=roughness_image, 155 | metallic=metallic_image, 156 | irradiance=irradiance_image, 157 | mask=mask, 158 | masked_image=masked_photo, 159 | photo=photo, 160 | num_inference_steps=inference_step, 161 | height=height, 162 | width=width, 163 | 
generator=generator, 164 | required_aovs=required_aovs, 165 | guidance_scale=guidance_scale, 166 | image_guidance_scale=image_guidance_scale, 167 | guidance_rescale=0.7, 168 | output_type="np", 169 | ).images 170 | generated_image = res[0][0] 171 | masked_photo_vae = res[1][0] 172 | photo_vae = res[2][0] 173 | generated_image = (generated_image, f"Generated Image {i}") 174 | return_list.append(generated_image) 175 | 176 | masked_photo_vae = (masked_photo_vae, "Masked photo") 177 | photo_vae = (photo_vae, "Photo") 178 | return_list.append(masked_photo_vae) 179 | return_list.append(photo_vae) 180 | 181 | return return_list 182 | 183 | block = gr.Blocks() 184 | with block: 185 | with gr.Row(): 186 | gr.Markdown( 187 | "## Model X -> RGB (Intrinsic channels -> realistic image) inpainting" 188 | ) 189 | with gr.Row(): 190 | # Input side 191 | with gr.Column(): 192 | gr.Markdown("### Given intrinsic channels") 193 | albedo = gr.File(label="Albedo", file_types=[".exr", ".png", ".jpg"]) 194 | normal = gr.File(label="Normal", file_types=[".exr", ".png", ".jpg"]) 195 | roughness = gr.File( 196 | label="Roughness", file_types=[".exr", ".png", ".jpg"] 197 | ) 198 | metallic = gr.File( 199 | label="Metallic", file_types=[".exr", ".png", ".jpg"] 200 | ) 201 | irradiance = gr.File( 202 | label="Irradiance", file_types=[".exr", ".png", ".jpg"] 203 | ) 204 | mask = gr.File(label="Mask", file_types=[".exr", ".png", ".jpg"]) 205 | photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"]) 206 | 207 | gr.Markdown("### Parameters") 208 | prompt = gr.Textbox(label="Prompt") 209 | run_button = gr.Button(label="Run") 210 | with gr.Accordion("Advanced options", open=False): 211 | seed = gr.Slider( 212 | label="Seed", 213 | minimum=-1, 214 | maximum=2147483647, 215 | step=1, 216 | randomize=True, 217 | ) 218 | inference_step = gr.Slider( 219 | label="Inference Step", 220 | minimum=1, 221 | maximum=200, 222 | step=1, 223 | value=50, 224 | ) 225 | num_samples = gr.Slider( 226 | label="Samples", 227 | minimum=1, 228 | maximum=100, 229 | step=1, 230 | value=1, 231 | ) 232 | guidance_scale = gr.Slider( 233 | label="Guidance Scale", 234 | minimum=0.0, 235 | maximum=10.0, 236 | step=0.1, 237 | value=7.5, 238 | ) 239 | image_guidance_scale = gr.Slider( 240 | label="Image Guidance Scale", 241 | minimum=0.0, 242 | maximum=10.0, 243 | step=0.1, 244 | value=1.5, 245 | ) 246 | 247 | # Output side 248 | with gr.Column(): 249 | gr.Markdown("### Output Gallery") 250 | result_gallery = gr.Gallery( 251 | label="Output", 252 | show_label=False, 253 | elem_id="gallery", 254 | columns=2, 255 | ) 256 | 257 | inputs = [ 258 | albedo, 259 | normal, 260 | roughness, 261 | metallic, 262 | irradiance, 263 | mask, 264 | photo, 265 | prompt, 266 | seed, 267 | inference_step, 268 | num_samples, 269 | guidance_scale, 270 | image_guidance_scale, 271 | ] 272 | run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True) 273 | 274 | return block 275 | 276 | 277 | if __name__ == "__main__": 278 | demo = get_x2rgb_demo() 279 | demo.queue(max_size=1) 280 | demo.launch() 281 | -------------------------------------------------------------------------------- /x2rgb_inpainting/load_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | 6 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 7 | import numpy as np 8 | 9 | 10 | def convert_rgb_2_XYZ(rgb): 11 | # Reference: 
https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html 12 | # rgb: (h, w, 3) 13 | # XYZ: (h, w, 3) 14 | XYZ = torch.ones_like(rgb) 15 | XYZ[:, :, 0] = ( 16 | 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2] 17 | ) 18 | XYZ[:, :, 1] = ( 19 | 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2] 20 | ) 21 | XYZ[:, :, 2] = ( 22 | 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2] 23 | ) 24 | return XYZ 25 | 26 | 27 | def convert_XYZ_2_Yxy(XYZ): 28 | # XYZ: (h, w, 3) 29 | # Yxy: (h, w, 3) 30 | Yxy = torch.ones_like(XYZ) 31 | Yxy[:, :, 0] = XYZ[:, :, 1] 32 | sum = torch.sum(XYZ, dim=2) 33 | inv_sum = 1.0 / torch.clamp(sum, min=1e-4) 34 | Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum 35 | Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum 36 | return Yxy 37 | 38 | 39 | def convert_rgb_2_Yxy(rgb): 40 | # rgb: (h, w, 3) 41 | # Yxy: (h, w, 3) 42 | return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb)) 43 | 44 | 45 | def convert_XYZ_2_rgb(XYZ): 46 | # XYZ: (h, w, 3) 47 | # rgb: (h, w, 3) 48 | rgb = torch.ones_like(XYZ) 49 | rgb[:, :, 0] = ( 50 | 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2] 51 | ) 52 | rgb[:, :, 1] = ( 53 | -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2] 54 | ) 55 | rgb[:, :, 2] = ( 56 | 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2] 57 | ) 58 | return rgb 59 | 60 | 61 | def convert_Yxy_2_XYZ(Yxy): 62 | # Yxy: (h, w, 3) 63 | # XYZ: (h, w, 3) 64 | XYZ = torch.ones_like(Yxy) 65 | XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0] 66 | XYZ[:, :, 1] = Yxy[:, :, 0] 67 | XYZ[:, :, 2] = ( 68 | (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2]) 69 | / torch.clamp(Yxy[:, :, 2], min=1e-4) 70 | * Yxy[:, :, 0] 71 | ) 72 | return XYZ 73 | 74 | 75 | def convert_Yxy_2_rgb(Yxy): 76 | # Yxy: (h, w, 3) 77 | # rgb: (h, w, 3) 78 | return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy)) 79 | 80 | 81 | def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False): 82 | # Load png or jpg image 83 | image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 84 | image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c) 85 | image[~torch.isfinite(image)] = 0 86 | if from_srgb: 87 | # Convert from sRGB to linear RGB 88 | image = image**2.2 89 | if clamp: 90 | image = torch.clamp(image, min=0.0, max=1.0) 91 | if normalize: 92 | # Normalize to [-1, 1] 93 | image = image * 2.0 - 1.0 94 | image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6) 95 | return image.permute(2, 0, 1) # returns (c, h, w) 96 | 97 | 98 | def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False): 99 | image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB) 100 | image = torch.from_numpy(image.astype("float32")) # (h, w, c) 101 | image[~torch.isfinite(image)] = 0 102 | if tonemaping: 103 | # Exposure adjuestment 104 | image_Yxy = convert_rgb_2_Yxy(image) 105 | lum = ( 106 | image[:, :, 0:1] * 0.2125 107 | + image[:, :, 1:2] * 0.7154 108 | + image[:, :, 2:3] * 0.0721 109 | ) 110 | lum = torch.log(torch.clamp(lum, min=1e-6)) 111 | lum_mean = torch.exp(torch.mean(lum)) 112 | lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6) 113 | image_Yxy[:, :, 0:1] = lp 114 | image = convert_Yxy_2_rgb(image_Yxy) 115 | if clamp: 116 | image = torch.clamp(image, min=0.0, max=1.0) 117 | if normalize: 118 | image = torch.nn.functional.normalize(image, dim=-1, 
eps=1e-6) 119 | return image.permute(2, 0, 1) # returns (c, h, w) 120 | -------------------------------------------------------------------------------- /x2rgb_inpainting/pipeline_x2rgb_inpainting.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from dataclasses import dataclass 4 | from typing import Callable, List, Optional, Union 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | import torch.nn.functional as F 10 | from diffusers.configuration_utils import register_to_config 11 | from diffusers.image_processor import VaeImageProcessor 12 | from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin 13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 14 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 15 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( 16 | rescale_noise_cfg, 17 | ) 18 | from diffusers.schedulers import KarrasDiffusionSchedulers 19 | from diffusers.utils import CONFIG_NAME, BaseOutput, deprecate, logging, randn_tensor 20 | from packaging import version 21 | from transformers import CLIPTextModel, CLIPTokenizer 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | class VaeImageProcrssorAOV(VaeImageProcessor): 27 | """ 28 | Image processor for VAE AOV. 29 | 30 | Args: 31 | do_resize (`bool`, *optional*, defaults to `True`): 32 | Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. 33 | vae_scale_factor (`int`, *optional*, defaults to `8`): 34 | VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. 35 | resample (`str`, *optional*, defaults to `lanczos`): 36 | Resampling filter to use when resizing the image. 37 | do_normalize (`bool`, *optional*, defaults to `True`): 38 | Whether to normalize the image to [-1,1]. 39 | """ 40 | 41 | config_name = CONFIG_NAME 42 | 43 | @register_to_config 44 | def __init__( 45 | self, 46 | do_resize: bool = True, 47 | vae_scale_factor: int = 8, 48 | resample: str = "lanczos", 49 | do_normalize: bool = True, 50 | ): 51 | super().__init__() 52 | 53 | def postprocess( 54 | self, 55 | image: torch.FloatTensor, 56 | output_type: str = "pil", 57 | do_denormalize: Optional[List[bool]] = None, 58 | do_gamma_correction: bool = True, 59 | ): 60 | if not isinstance(image, torch.Tensor): 61 | raise ValueError( 62 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" 63 | ) 64 | if output_type not in ["latent", "pt", "np", "pil"]: 65 | deprecation_message = ( 66 | f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " 67 | "`pil`, `np`, `pt`, `latent`" 68 | ) 69 | deprecate( 70 | "Unsupported output_type", 71 | "1.0.0", 72 | deprecation_message, 73 | standard_warn=False, 74 | ) 75 | output_type = "np" 76 | 77 | if output_type == "latent": 78 | return image 79 | 80 | if do_denormalize is None: 81 | do_denormalize = [self.config.do_normalize] * image.shape[0] 82 | 83 | image = torch.stack( 84 | [ 85 | self.denormalize(image[i]) if do_denormalize[i] else image[i] 86 | for i in range(image.shape[0]) 87 | ] 88 | ) 89 | 90 | # Gamma correction 91 | if do_gamma_correction: 92 | image = torch.pow(image, 1.0 / 2.2) 93 | 94 | if output_type == "pt": 95 | return image 96 | 97 | image = self.pt_to_numpy(image) 98 | 99 | if output_type == "np": 100 | return image 101 | 102 | if output_type == "pil": 103 | return self.numpy_to_pil(image) 104 | 105 | def preprocess_normal( 106 | self, 107 | image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], 108 | height: Optional[int] = None, 109 | width: Optional[int] = None, 110 | ) -> torch.Tensor: 111 | image = torch.stack([image], axis=0) 112 | return image 113 | 114 | 115 | @dataclass 116 | class StableDiffusionAOVPipelineOutput(BaseOutput): 117 | """ 118 | Output class for Stable Diffusion AOV pipelines. 119 | 120 | Args: 121 | images (`List[PIL.Image.Image]` or `np.ndarray`) 122 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 123 | num_channels)`. 124 | nsfw_content_detected (`List[bool]`) 125 | List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or 126 | `None` if safety checking could not be performed. 127 | """ 128 | 129 | images: Union[List[PIL.Image.Image], np.ndarray] 130 | predicted_x0_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] = None 131 | 132 | 133 | class StableDiffusionAOVDropoutPipeline( 134 | DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin 135 | ): 136 | r""" 137 | Pipeline for AOVs. 138 | 139 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 140 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 141 | 142 | The pipeline also inherits the following loading methods: 143 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 144 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 145 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 146 | 147 | Args: 148 | vae ([`AutoencoderKL`]): 149 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 150 | text_encoder ([`~transformers.CLIPTextModel`]): 151 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 152 | tokenizer ([`~transformers.CLIPTokenizer`]): 153 | A `CLIPTokenizer` to tokenize text. 154 | unet ([`UNet2DConditionModel`]): 155 | A `UNet2DConditionModel` to denoise the encoded image latents. 156 | scheduler ([`SchedulerMixin`]): 157 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 158 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | vae: AutoencoderKL, 164 | text_encoder: CLIPTextModel, 165 | tokenizer: CLIPTokenizer, 166 | unet: UNet2DConditionModel, 167 | scheduler: KarrasDiffusionSchedulers, 168 | ): 169 | super().__init__() 170 | 171 | self.register_modules( 172 | vae=vae, 173 | text_encoder=text_encoder, 174 | tokenizer=tokenizer, 175 | unet=unet, 176 | scheduler=scheduler, 177 | ) 178 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 179 | self.image_processor = VaeImageProcrssorAOV( 180 | vae_scale_factor=self.vae_scale_factor 181 | ) 182 | self.register_to_config() 183 | 184 | def _encode_prompt( 185 | self, 186 | prompt, 187 | device, 188 | num_images_per_prompt, 189 | do_classifier_free_guidance, 190 | negative_prompt=None, 191 | prompt_embeds: Optional[torch.FloatTensor] = None, 192 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 193 | ): 194 | r""" 195 | Encodes the prompt into text encoder hidden states. 196 | 197 | Args: 198 | prompt (`str` or `List[str]`, *optional*): 199 | prompt to be encoded 200 | device: (`torch.device`): 201 | torch device 202 | num_images_per_prompt (`int`): 203 | number of images that should be generated per prompt 204 | do_classifier_free_guidance (`bool`): 205 | whether to use classifier free guidance or not 206 | negative_ prompt (`str` or `List[str]`, *optional*): 207 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 208 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 209 | less than `1`). 210 | prompt_embeds (`torch.FloatTensor`, *optional*): 211 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 212 | provided, text embeddings will be generated from `prompt` input argument. 213 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 214 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 215 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 216 | argument. 
217 | """ 218 | if prompt is not None and isinstance(prompt, str): 219 | batch_size = 1 220 | elif prompt is not None and isinstance(prompt, list): 221 | batch_size = len(prompt) 222 | else: 223 | batch_size = prompt_embeds.shape[0] 224 | 225 | if prompt_embeds is None: 226 | # textual inversion: procecss multi-vector tokens if necessary 227 | if isinstance(self, TextualInversionLoaderMixin): 228 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 229 | 230 | text_inputs = self.tokenizer( 231 | prompt, 232 | padding="max_length", 233 | max_length=self.tokenizer.model_max_length, 234 | truncation=True, 235 | return_tensors="pt", 236 | ) 237 | text_input_ids = text_inputs.input_ids 238 | untruncated_ids = self.tokenizer( 239 | prompt, padding="longest", return_tensors="pt" 240 | ).input_ids 241 | 242 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 243 | -1 244 | ] and not torch.equal(text_input_ids, untruncated_ids): 245 | removed_text = self.tokenizer.batch_decode( 246 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 247 | ) 248 | logger.warning( 249 | "The following part of your input was truncated because CLIP can only handle sequences up to" 250 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 251 | ) 252 | 253 | if ( 254 | hasattr(self.text_encoder.config, "use_attention_mask") 255 | and self.text_encoder.config.use_attention_mask 256 | ): 257 | attention_mask = text_inputs.attention_mask.to(device) 258 | else: 259 | attention_mask = None 260 | 261 | prompt_embeds = self.text_encoder( 262 | text_input_ids.to(device), 263 | attention_mask=attention_mask, 264 | ) 265 | prompt_embeds = prompt_embeds[0] 266 | 267 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 268 | 269 | bs_embed, seq_len, _ = prompt_embeds.shape 270 | # duplicate text embeddings for each generation per prompt, using mps friendly method 271 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 272 | prompt_embeds = prompt_embeds.view( 273 | bs_embed * num_images_per_prompt, seq_len, -1 274 | ) 275 | 276 | # get unconditional embeddings for classifier free guidance 277 | if do_classifier_free_guidance and negative_prompt_embeds is None: 278 | uncond_tokens: List[str] 279 | if negative_prompt is None: 280 | uncond_tokens = [""] * batch_size 281 | elif type(prompt) is not type(negative_prompt): 282 | raise TypeError( 283 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 284 | f" {type(prompt)}." 285 | ) 286 | elif isinstance(negative_prompt, str): 287 | uncond_tokens = [negative_prompt] 288 | elif batch_size != len(negative_prompt): 289 | raise ValueError( 290 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 291 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 292 | " the batch size of `prompt`." 
293 | ) 294 | else: 295 | uncond_tokens = negative_prompt 296 | 297 | # textual inversion: procecss multi-vector tokens if necessary 298 | if isinstance(self, TextualInversionLoaderMixin): 299 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 300 | 301 | max_length = prompt_embeds.shape[1] 302 | uncond_input = self.tokenizer( 303 | uncond_tokens, 304 | padding="max_length", 305 | max_length=max_length, 306 | truncation=True, 307 | return_tensors="pt", 308 | ) 309 | 310 | if ( 311 | hasattr(self.text_encoder.config, "use_attention_mask") 312 | and self.text_encoder.config.use_attention_mask 313 | ): 314 | attention_mask = uncond_input.attention_mask.to(device) 315 | else: 316 | attention_mask = None 317 | 318 | negative_prompt_embeds = self.text_encoder( 319 | uncond_input.input_ids.to(device), 320 | attention_mask=attention_mask, 321 | ) 322 | negative_prompt_embeds = negative_prompt_embeds[0] 323 | 324 | if do_classifier_free_guidance: 325 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 326 | seq_len = negative_prompt_embeds.shape[1] 327 | 328 | negative_prompt_embeds = negative_prompt_embeds.to( 329 | dtype=self.text_encoder.dtype, device=device 330 | ) 331 | 332 | negative_prompt_embeds = negative_prompt_embeds.repeat( 333 | 1, num_images_per_prompt, 1 334 | ) 335 | negative_prompt_embeds = negative_prompt_embeds.view( 336 | batch_size * num_images_per_prompt, seq_len, -1 337 | ) 338 | 339 | # For classifier free guidance, we need to do two forward passes. 340 | # Here we concatenate the unconditional and text embeddings into a single batch 341 | # to avoid doing two forward passes 342 | # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 343 | prompt_embeds = torch.cat( 344 | [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] 345 | ) 346 | 347 | return prompt_embeds 348 | 349 | def prepare_extra_step_kwargs(self, generator, eta): 350 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 351 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 352 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 353 | # and should be between [0, 1] 354 | 355 | accepts_eta = "eta" in set( 356 | inspect.signature(self.scheduler.step).parameters.keys() 357 | ) 358 | extra_step_kwargs = {} 359 | if accepts_eta: 360 | extra_step_kwargs["eta"] = eta 361 | 362 | # check if the scheduler accepts generator 363 | accepts_generator = "generator" in set( 364 | inspect.signature(self.scheduler.step).parameters.keys() 365 | ) 366 | if accepts_generator: 367 | extra_step_kwargs["generator"] = generator 368 | return extra_step_kwargs 369 | 370 | def check_inputs( 371 | self, 372 | prompt, 373 | callback_steps, 374 | negative_prompt=None, 375 | prompt_embeds=None, 376 | negative_prompt_embeds=None, 377 | ): 378 | if (callback_steps is None) or ( 379 | callback_steps is not None 380 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 381 | ): 382 | raise ValueError( 383 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 384 | f" {type(callback_steps)}." 385 | ) 386 | 387 | if prompt is not None and prompt_embeds is not None: 388 | raise ValueError( 389 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" 390 | " only forward one of the two." 391 | ) 392 | elif prompt is None and prompt_embeds is None: 393 | raise ValueError( 394 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 395 | ) 396 | elif prompt is not None and ( 397 | not isinstance(prompt, str) and not isinstance(prompt, list) 398 | ): 399 | raise ValueError( 400 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 401 | ) 402 | 403 | if negative_prompt is not None and negative_prompt_embeds is not None: 404 | raise ValueError( 405 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 406 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 407 | ) 408 | 409 | if prompt_embeds is not None and negative_prompt_embeds is not None: 410 | if prompt_embeds.shape != negative_prompt_embeds.shape: 411 | raise ValueError( 412 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 413 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 414 | f" {negative_prompt_embeds.shape}." 415 | ) 416 | 417 | def prepare_latents( 418 | self, 419 | batch_size, 420 | num_channels_latents, 421 | height, 422 | width, 423 | dtype, 424 | device, 425 | generator, 426 | latents=None, 427 | ): 428 | shape = ( 429 | batch_size, 430 | num_channels_latents, 431 | height // self.vae_scale_factor, 432 | width // self.vae_scale_factor, 433 | ) 434 | if isinstance(generator, list) and len(generator) != batch_size: 435 | raise ValueError( 436 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 437 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 438 | ) 439 | 440 | if latents is None: 441 | latents = randn_tensor( 442 | shape, generator=generator, device=device, dtype=dtype 443 | ) 444 | else: 445 | latents = latents.to(device) 446 | 447 | # scale the initial noise by the standard deviation required by the scheduler 448 | latents = latents * self.scheduler.init_noise_sigma 449 | return latents 450 | 451 | def prepare_image_latents( 452 | self, 453 | image, 454 | batch_size, 455 | num_images_per_prompt, 456 | dtype, 457 | device, 458 | do_classifier_free_guidance, 459 | generator=None, 460 | ): 461 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 462 | raise ValueError( 463 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 464 | ) 465 | 466 | image = image.to(device=device, dtype=dtype) 467 | 468 | batch_size = batch_size * num_images_per_prompt 469 | 470 | if image.shape[1] == 4: 471 | image_latents = image 472 | else: 473 | if isinstance(generator, list) and len(generator) != batch_size: 474 | raise ValueError( 475 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 476 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
477 | ) 478 | 479 | if isinstance(generator, list): 480 | image_latents = [ 481 | self.vae.encode(image[i : i + 1]).latent_dist.mode() 482 | for i in range(batch_size) 483 | ] 484 | image_latents = torch.cat(image_latents, dim=0) 485 | else: 486 | image_latents = self.vae.encode(image).latent_dist.mode() 487 | 488 | if ( 489 | batch_size > image_latents.shape[0] 490 | and batch_size % image_latents.shape[0] == 0 491 | ): 492 | # expand image_latents for batch_size 493 | deprecation_message = ( 494 | f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" 495 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 496 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 497 | " your script to pass as many initial images as text prompts to suppress this warning." 498 | ) 499 | deprecate( 500 | "len(prompt) != len(image)", 501 | "1.0.0", 502 | deprecation_message, 503 | standard_warn=False, 504 | ) 505 | additional_image_per_prompt = batch_size // image_latents.shape[0] 506 | image_latents = torch.cat( 507 | [image_latents] * additional_image_per_prompt, dim=0 508 | ) 509 | elif ( 510 | batch_size > image_latents.shape[0] 511 | and batch_size % image_latents.shape[0] != 0 512 | ): 513 | raise ValueError( 514 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 515 | ) 516 | else: 517 | image_latents = torch.cat([image_latents], dim=0) 518 | 519 | if do_classifier_free_guidance: 520 | uncond_image_latents = torch.zeros_like(image_latents) 521 | image_latents = torch.cat( 522 | [image_latents, image_latents, uncond_image_latents], dim=0 523 | ) 524 | 525 | return image_latents 526 | 527 | @torch.no_grad() 528 | def __call__( 529 | self, 530 | height: int, 531 | width: int, 532 | prompt: Union[str, List[str]] = None, 533 | albedo: Optional[ 534 | Union[ 535 | torch.FloatTensor, 536 | PIL.Image.Image, 537 | np.ndarray, 538 | List[torch.FloatTensor], 539 | List[PIL.Image.Image], 540 | List[np.ndarray], 541 | ] 542 | ] = None, 543 | normal: Optional[ 544 | Union[ 545 | torch.FloatTensor, 546 | PIL.Image.Image, 547 | np.ndarray, 548 | List[torch.FloatTensor], 549 | List[PIL.Image.Image], 550 | List[np.ndarray], 551 | ] 552 | ] = None, 553 | roughness: Optional[ 554 | Union[ 555 | torch.FloatTensor, 556 | PIL.Image.Image, 557 | np.ndarray, 558 | List[torch.FloatTensor], 559 | List[PIL.Image.Image], 560 | List[np.ndarray], 561 | ] 562 | ] = None, 563 | metallic: Optional[ 564 | Union[ 565 | torch.FloatTensor, 566 | PIL.Image.Image, 567 | np.ndarray, 568 | List[torch.FloatTensor], 569 | List[PIL.Image.Image], 570 | List[np.ndarray], 571 | ] 572 | ] = None, 573 | irradiance: Optional[ 574 | Union[ 575 | torch.FloatTensor, 576 | PIL.Image.Image, 577 | np.ndarray, 578 | List[torch.FloatTensor], 579 | List[PIL.Image.Image], 580 | List[np.ndarray], 581 | ] 582 | ] = None, 583 | mask: Optional[ 584 | Union[ 585 | torch.FloatTensor, 586 | PIL.Image.Image, 587 | np.ndarray, 588 | List[torch.FloatTensor], 589 | List[PIL.Image.Image], 590 | List[np.ndarray], 591 | ] 592 | ] = None, 593 | masked_image: Optional[ 594 | Union[ 595 | torch.FloatTensor, 596 | PIL.Image.Image, 597 | np.ndarray, 598 | List[torch.FloatTensor], 599 | List[PIL.Image.Image], 600 | List[np.ndarray], 601 | ] 602 | ] = None, 603 | photo: Optional[ 604 | Union[ 605 | torch.FloatTensor, 606 | PIL.Image.Image, 607 | np.ndarray, 608 | List[torch.FloatTensor], 609 | 
List[PIL.Image.Image], 610 | List[np.ndarray], 611 | ] 612 | ] = None, 613 | guidance_scale: float = 0, 614 | image_guidance_scale: float = 0, 615 | guidance_rescale: float = 0.0, 616 | num_inference_steps: int = 100, 617 | required_aovs: List[str] = ["albedo"], 618 | return_predicted_x0s: bool = False, 619 | negative_prompt: Optional[Union[str, List[str]]] = None, 620 | num_images_per_prompt: Optional[int] = 1, 621 | eta: float = 0.0, 622 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 623 | latents: Optional[torch.FloatTensor] = None, 624 | prompt_embeds: Optional[torch.FloatTensor] = None, 625 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 626 | output_type: Optional[str] = "pil", 627 | return_dict: bool = True, 628 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 629 | callback_steps: int = 1, 630 | ): 631 | r""" 632 | The call function to the pipeline for generation. 633 | 634 | Args: 635 | prompt (`str` or `List[str]`, *optional*): 636 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 637 | image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 638 | `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept 639 | image latents as `image`, but if passing latents directly it is not encoded again. 640 | num_inference_steps (`int`, *optional*, defaults to 100): 641 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 642 | expense of slower inference. 643 | guidance_scale (`float`, *optional*, defaults to 7.5): 644 | A higher guidance scale value encourages the model to generate images closely linked to the text 645 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 646 | image_guidance_scale (`float`, *optional*, defaults to 1.5): 647 | Push the generated image towards the inital `image`. Image guidance scale is enabled by setting 648 | `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely 649 | linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a 650 | value of at least `1`. 651 | negative_prompt (`str` or `List[str]`, *optional*): 652 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 653 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 654 | num_images_per_prompt (`int`, *optional*, defaults to 1): 655 | The number of images to generate per prompt. 656 | eta (`float`, *optional*, defaults to 0.0): 657 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 658 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 659 | generator (`torch.Generator`, *optional*): 660 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 661 | generation deterministic. 662 | latents (`torch.FloatTensor`, *optional*): 663 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 664 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 665 | tensor is generated by sampling using the supplied random `generator`. 
666 | prompt_embeds (`torch.FloatTensor`, *optional*): 667 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 668 | provided, text embeddings are generated from the `prompt` input argument. 669 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 670 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 671 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 672 | output_type (`str`, *optional*, defaults to `"pil"`): 673 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 674 | return_dict (`bool`, *optional*, defaults to `True`): 675 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 676 | plain tuple. 677 | callback (`Callable`, *optional*): 678 | A function that calls every `callback_steps` steps during inference. The function is called with the 679 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 680 | callback_steps (`int`, *optional*, defaults to 1): 681 | The frequency at which the `callback` function is called. If not specified, the callback is called at 682 | every step. 683 | 684 | Examples: 685 | 686 | ```py 687 | >>> import PIL 688 | >>> import requests 689 | >>> import torch 690 | >>> from io import BytesIO 691 | 692 | >>> from diffusers import StableDiffusionInstructPix2PixPipeline 693 | 694 | 695 | >>> def download_image(url): 696 | ... response = requests.get(url) 697 | ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") 698 | 699 | 700 | >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" 701 | 702 | >>> image = download_image(img_url).resize((512, 512)) 703 | 704 | >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( 705 | ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 706 | ... ) 707 | >>> pipe = pipe.to("cuda") 708 | 709 | >>> prompt = "make the mountains snowy" 710 | >>> image = pipe(prompt=prompt, image=image).images[0] 711 | ``` 712 | 713 | Returns: 714 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 715 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, 716 | otherwise a `tuple` is returned where the first element is a list with the generated images and the 717 | second element is a list of `bool`s indicating whether the corresponding generated image contains 718 | "not-safe-for-work" (nsfw) content. 719 | """ 720 | # 0. Check inputs 721 | self.check_inputs( 722 | prompt, 723 | callback_steps, 724 | negative_prompt, 725 | prompt_embeds, 726 | negative_prompt_embeds, 727 | ) 728 | 729 | # 1. Define call parameters 730 | if prompt is not None and isinstance(prompt, str): 731 | batch_size = 1 732 | elif prompt is not None and isinstance(prompt, list): 733 | batch_size = len(prompt) 734 | else: 735 | batch_size = prompt_embeds.shape[0] 736 | 737 | device = self._execution_device 738 | do_classifier_free_guidance = ( 739 | guidance_scale >= 1.0 and image_guidance_scale >= 1.0 740 | ) 741 | # check if scheduler is in sigmas space 742 | scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") 743 | 744 | # 2. 
Encode input prompt 745 | prompt_embeds = self._encode_prompt( 746 | prompt, 747 | device, 748 | num_images_per_prompt, 749 | do_classifier_free_guidance, 750 | negative_prompt, 751 | prompt_embeds=prompt_embeds, 752 | negative_prompt_embeds=negative_prompt_embeds, 753 | ) 754 | 755 | # 3. Preprocess image 756 | # For normal, the preprocessing does nothing 757 | # For others, the preprocessing remap the values to [-1, 1] 758 | preprocessed_aovs = {} 759 | 760 | assert mask is not None 761 | masked_image = self.image_processor.preprocess(masked_image) 762 | photo = self.image_processor.preprocess(photo) 763 | 764 | for aov_name in required_aovs: 765 | if aov_name == "albedo": 766 | if albedo is not None: 767 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 768 | albedo * (1 - mask) 769 | ) 770 | else: 771 | preprocessed_aovs[aov_name] = None 772 | 773 | if aov_name == "normal": 774 | if normal is not None: 775 | preprocessed_aovs[aov_name] = ( 776 | self.image_processor.preprocess_normal(normal * (1 - mask)) 777 | ) 778 | else: 779 | preprocessed_aovs[aov_name] = None 780 | 781 | if aov_name == "roughness": 782 | if roughness is not None: 783 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 784 | roughness * (1 - mask) 785 | ) 786 | else: 787 | preprocessed_aovs[aov_name] = None 788 | if aov_name == "metallic": 789 | if metallic is not None: 790 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 791 | metallic * (1 - mask) 792 | ) 793 | else: 794 | preprocessed_aovs[aov_name] = None 795 | if aov_name == "irradiance": 796 | if irradiance is not None: 797 | preprocessed_aovs[aov_name] = self.image_processor.preprocess( 798 | irradiance * (1 - mask) 799 | ) 800 | else: 801 | preprocessed_aovs[aov_name] = None 802 | 803 | # 4. set timesteps 804 | self.scheduler.set_timesteps(num_inference_steps, device=device) 805 | timesteps = self.scheduler.timesteps 806 | 807 | # 5. Prepare latent variables 808 | num_channels_latents = self.vae.config.latent_channels 809 | latents = self.prepare_latents( 810 | batch_size * num_images_per_prompt, 811 | num_channels_latents, 812 | height, 813 | width, 814 | prompt_embeds.dtype, 815 | device, 816 | generator, 817 | latents, 818 | ) 819 | 820 | height_latent, width_latent = latents.shape[-2:] 821 | 822 | # 6. 
Prepare Image latents 823 | image_latents = [] 824 | scaling_factors = { 825 | "albedo": 0.17301377137652138, 826 | "normal": 0.17483895473058078, 827 | "roughness": 0.1680724853626448, 828 | "metallic": 0.13135013390855135, 829 | } 830 | for aov_name, aov in preprocessed_aovs.items(): 831 | if aov is None: 832 | image_latent = torch.zeros( 833 | batch_size, 834 | num_channels_latents, 835 | height_latent, 836 | width_latent, 837 | dtype=prompt_embeds.dtype, 838 | device=device, 839 | ) 840 | if aov_name == "irradiance": 841 | image_latent = image_latent[:, 0:3] 842 | if do_classifier_free_guidance: 843 | image_latents.append( 844 | torch.cat([image_latent, image_latent, image_latent], dim=0) 845 | ) 846 | else: 847 | image_latents.append(image_latent) 848 | else: 849 | if aov_name == "irradiance": 850 | image_latent = F.interpolate( 851 | aov.to(device=device, dtype=prompt_embeds.dtype), 852 | size=(height_latent, width_latent), 853 | mode="bilinear", 854 | align_corners=False, 855 | antialias=True, 856 | ) 857 | if do_classifier_free_guidance: 858 | uncond_image_latent = torch.zeros_like(image_latent) 859 | image_latent = torch.cat( 860 | [image_latent, image_latent, uncond_image_latent], dim=0 861 | ) 862 | else: 863 | scaling_factor = scaling_factors[aov_name] 864 | image_latent = ( 865 | self.prepare_image_latents( 866 | aov, 867 | batch_size, 868 | num_images_per_prompt, 869 | prompt_embeds.dtype, 870 | device, 871 | do_classifier_free_guidance, 872 | generator, 873 | ) 874 | * scaling_factor 875 | ) 876 | image_latents.append(image_latent) 877 | 878 | masked_image_latents = ( 879 | self.prepare_image_latents( 880 | masked_image, 881 | batch_size, 882 | num_images_per_prompt, 883 | prompt_embeds.dtype, 884 | device, 885 | do_classifier_free_guidance, 886 | generator, 887 | ) 888 | * self.vae.config.scaling_factor 889 | ) 890 | photo_latents = ( 891 | self.prepare_image_latents( 892 | photo, 893 | batch_size, 894 | num_images_per_prompt, 895 | prompt_embeds.dtype, 896 | device, 897 | do_classifier_free_guidance, 898 | generator, 899 | ) 900 | * self.vae.config.scaling_factor 901 | ) 902 | mask_latents = F.interpolate( 903 | mask.to(device=device, dtype=prompt_embeds.dtype).unsqueeze(0), 904 | size=(height_latent, width_latent), 905 | mode="bilinear", 906 | align_corners=False, 907 | antialias=True, 908 | ) 909 | if do_classifier_free_guidance: 910 | uncond_mask_latent = torch.zeros_like(mask_latents) 911 | mask_latents = torch.cat( 912 | [mask_latents, mask_latents, uncond_mask_latent], dim=0 913 | ) 914 | image_latents.append(masked_image_latents) 915 | image_latents.append(mask_latents) 916 | 917 | image_latents = torch.cat(image_latents, dim=1) 918 | 919 | # 7. Check that shapes of latents and image match the UNet channels 920 | num_channels_image = image_latents.shape[1] 921 | if num_channels_latents + num_channels_image != self.unet.config.in_channels: 922 | raise ValueError( 923 | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" 924 | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" 925 | f" `num_channels_image`: {num_channels_image} " 926 | f" = {num_channels_latents+num_channels_image}. Please verify the config of" 927 | " `pipeline.unet` or your `image` input." 928 | ) 929 | 930 | # 8. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 931 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 932 | 933 | predicted_x0s = [] 934 | 935 | # 9. Denoising loop 936 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 937 | with self.progress_bar(total=num_inference_steps) as progress_bar: 938 | for i, t in enumerate(timesteps): 939 | # Expand the latents if we are doing classifier free guidance. 940 | # The latents are expanded 3 times because for pix2pix the guidance\ 941 | # is applied for both the text and the input image. 942 | latent_model_input = ( 943 | torch.cat([latents] * 3) if do_classifier_free_guidance else latents 944 | ) 945 | 946 | # concat latents, image_latents in the channel dimension 947 | scaled_latent_model_input = self.scheduler.scale_model_input( 948 | latent_model_input, t 949 | ) 950 | scaled_latent_model_input = torch.cat( 951 | [scaled_latent_model_input, image_latents], dim=1 952 | ) 953 | 954 | # predict the noise residual 955 | noise_pred = self.unet( 956 | scaled_latent_model_input, 957 | t, 958 | encoder_hidden_states=prompt_embeds, 959 | return_dict=False, 960 | )[0] 961 | 962 | # perform guidance 963 | if do_classifier_free_guidance: 964 | ( 965 | noise_pred_text, 966 | noise_pred_image, 967 | noise_pred_uncond, 968 | ) = noise_pred.chunk(3) 969 | noise_pred = ( 970 | noise_pred_uncond 971 | + guidance_scale * (noise_pred_text - noise_pred_image) 972 | + image_guidance_scale * (noise_pred_image - noise_pred_uncond) 973 | ) 974 | 975 | if do_classifier_free_guidance and guidance_rescale > 0.0: 976 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 977 | noise_pred = rescale_noise_cfg( 978 | noise_pred, noise_pred_text, guidance_rescale=guidance_rescale 979 | ) 980 | 981 | # compute the previous noisy sample x_t -> x_t-1 982 | output = self.scheduler.step( 983 | noise_pred, t, latents, **extra_step_kwargs, return_dict=True 984 | ) 985 | 986 | latents = output[0] 987 | 988 | if return_predicted_x0s: 989 | predicted_x0s.append(output[1]) 990 | 991 | # call the callback, if provided 992 | if i == len(timesteps) - 1 or ( 993 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 994 | ): 995 | progress_bar.update() 996 | if callback is not None and i % callback_steps == 0: 997 | callback(i, t, latents) 998 | 999 | if not output_type == "latent": 1000 | image = self.vae.decode( 1001 | latents / self.vae.config.scaling_factor, return_dict=False 1002 | )[0] 1003 | 1004 | if return_predicted_x0s: 1005 | predicted_x0_images = [ 1006 | self.vae.decode( 1007 | predicted_x0 / self.vae.config.scaling_factor, return_dict=False 1008 | )[0] 1009 | for predicted_x0 in predicted_x0s 1010 | ] 1011 | else: 1012 | image = latents 1013 | predicted_x0_images = predicted_x0s 1014 | 1015 | do_denormalize = [True] * image.shape[0] 1016 | 1017 | image = self.image_processor.postprocess( 1018 | image, output_type=output_type, do_denormalize=do_denormalize 1019 | ) 1020 | masked_image = self.vae.decode( 1021 | masked_image_latents[0:1] / self.vae.config.scaling_factor, 1022 | return_dict=False, 1023 | )[0] 1024 | masked_image = self.image_processor.postprocess( 1025 | masked_image, output_type=output_type, do_denormalize=do_denormalize 1026 | ) 1027 | 1028 | photo = self.vae.decode( 1029 | photo_latents[0:1] / self.vae.config.scaling_factor, return_dict=False 1030 | )[0] 1031 | photo = self.image_processor.postprocess( 1032 | photo, output_type=output_type, do_denormalize=do_denormalize 1033 | ) 1034 
| image = [image, masked_image, photo] 1035 | 1036 | if return_predicted_x0s: 1037 | predicted_x0_images = [ 1038 | self.image_processor.postprocess( 1039 | predicted_x0_image, 1040 | output_type=output_type, 1041 | do_denormalize=do_denormalize, 1042 | ) 1043 | for predicted_x0_image in predicted_x0_images 1044 | ] 1045 | 1046 | # Offload last model to CPU 1047 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 1048 | self.final_offload_hook.offload() 1049 | 1050 | if not return_dict: 1051 | return image 1052 | 1053 | if return_predicted_x0s: 1054 | return StableDiffusionAOVPipelineOutput( 1055 | images=image, predicted_x0_images=predicted_x0_images 1056 | ) 1057 | else: 1058 | return StableDiffusionAOVPipelineOutput(images=image) 1059 | --------------------------------------------------------------------------------
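
Usage note: the inpainting pipeline listed above is normally driven through the ComfyUI wrapper node rather than run on its own, but a minimal standalone sketch may help clarify what `StableDiffusionAOVDropoutPipeline.__call__` expects. Every intrinsic channel (albedo, normal, roughness, metallic, irradiance) is optional and is zeroed inside the masked region before encoding, while `mask`, `masked_image`, and `photo` are required. This is only a sketch under stated assumptions: the checkpoint directory, image file names, prompt, and 512x512 resolution below are placeholders, not values taken from this repository, and the imports assume the script sits next to the `x2rgb_inpainting` sources.

    # Hypothetical standalone invocation; checkpoint path and image files are placeholders.
    import torch
    from load_image import load_ldr_image                      # loader helpers shown earlier in this repo
    from pipeline_x2rgb_inpainting import StableDiffusionAOVDropoutPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumes a diffusers-format checkpoint directory (vae/, text_encoder/, tokenizer/, unet/, scheduler/).
    pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
        "path/to/x2rgb-inpainting-checkpoint", torch_dtype=torch.float32
    ).to(device)

    # Intrinsic channels as (c, h, w) tensors; load_ldr_image already returns that layout.
    # The sketch assumes all inputs are already 512x512.
    albedo = load_ldr_image("albedo.png", from_srgb=True).to(device)    # linear RGB in [0, 1]
    normal = load_ldr_image("normal.png", normalize=True).to(device)    # unit vectors in [-1, 1]
    roughness = load_ldr_image("roughness.png").to(device)
    metallic = load_ldr_image("metallic.png").to(device)
    photo = load_ldr_image("photo.png", from_srgb=True).to(device)
    mask = load_ldr_image("mask.png").to(device)[0:1]                   # single channel, 1 = region to repaint
    masked_image = photo * (1.0 - mask)

    result = pipe(
        height=512,
        width=512,
        prompt="a photorealistic interior",
        albedo=albedo,
        normal=normal,
        roughness=roughness,
        metallic=metallic,
        irradiance=None,                  # a missing AOV is replaced by zero latents inside the pipeline
        mask=mask,
        masked_image=masked_image,
        photo=photo,
        guidance_scale=7.5,               # both scales >= 1 enable classifier-free guidance
        image_guidance_scale=1.5,
        num_inference_steps=50,
        required_aovs=["albedo", "normal", "roughness", "metallic", "irradiance"],
    )
    # The pipeline returns the generated image plus decoded previews of the masked input and photo.
    generated, masked_preview, photo_preview = result.images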