├── .gitignore ├── DAI ├── __init__.py ├── controlnetvae.py ├── decoder.py ├── pipeline_all.py └── pipeline_onestep.py ├── README.md ├── assets ├── logo.png ├── logo_old.png ├── teaser.mp4 └── teaser.png ├── demo.py ├── files └── image │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ └── 8.png ├── input ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png └── 8.png ├── requirements.txt ├── run.py ├── run.sh └── utils ├── image_utils.py └── loss_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | weights/ -------------------------------------------------------------------------------- /DAI/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/DAI/__init__.py -------------------------------------------------------------------------------- /DAI/controlnetvae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders.single_file_model import FromOriginalModelMixin 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.models.attention_processor import ( 25 | ADDED_KV_ATTENTION_PROCESSORS, 26 | CROSS_ATTENTION_PROCESSORS, 27 | AttentionProcessor, 28 | AttnAddedKVProcessor, 29 | AttnProcessor, 30 | ) 31 | from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps 32 | from diffusers.models.modeling_utils import ModelMixin 33 | from diffusers.models.unets.unet_2d_blocks import ( 34 | CrossAttnDownBlock2D, 35 | DownBlock2D, 36 | UNetMidBlock2D, 37 | UNetMidBlock2DCrossAttn, 38 | get_down_block, 39 | ) 40 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 41 | from diffusers.models.controlnet import ControlNetOutput 42 | from diffusers.models import ControlNetModel 43 | 44 | import pdb 45 | 46 | 47 | class ControlNetVAEModel(ControlNetModel): 48 | def forward( 49 | self, 50 | sample: torch.Tensor, 51 | timestep: Union[torch.Tensor, float, int], 52 | encoder_hidden_states: torch.Tensor, 53 | controlnet_cond: torch.Tensor = None, 54 | conditioning_scale: float = 1.0, 55 | class_labels: Optional[torch.Tensor] = None, 56 | timestep_cond: Optional[torch.Tensor] = None, 57 | attention_mask: Optional[torch.Tensor] = None, 58 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 59 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 60 | guess_mode: bool 
= False, 61 | return_dict: bool = True, 62 | ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: 63 | """ 64 | The [`ControlNetVAEModel`] forward method. 65 | 66 | Args: 67 | sample (`torch.Tensor`): 68 | The noisy input tensor. 69 | timestep (`Union[torch.Tensor, float, int]`): 70 | The number of timesteps to denoise an input. 71 | encoder_hidden_states (`torch.Tensor`): 72 | The encoder hidden states. 73 | controlnet_cond (`torch.Tensor`): 74 | The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. 75 | conditioning_scale (`float`, defaults to `1.0`): 76 | The scale factor for ControlNet outputs. 77 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 78 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 79 | timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): 80 | Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the 81 | timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep 82 | embeddings. 83 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 84 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 85 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 86 | negative values to the attention scores corresponding to "discard" tokens. 87 | added_cond_kwargs (`dict`): 88 | Additional conditions for the Stable Diffusion XL UNet. 89 | cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): 90 | A kwargs dictionary that if specified is passed along to the `AttnProcessor`. 91 | guess_mode (`bool`, defaults to `False`): 92 | In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if 93 | you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. 94 | return_dict (`bool`, defaults to `True`): 95 | Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. 96 | 97 | Returns: 98 | [`~models.controlnet.ControlNetOutput`] **or** `tuple`: 99 | If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is 100 | returned where the first element is the sample tensor. 101 | """ 102 | # check channel order 103 | 104 | 105 | channel_order = self.config.controlnet_conditioning_channel_order 106 | 107 | if channel_order == "rgb": 108 | # in rgb order by default 109 | ... 110 | elif channel_order == "bgr": 111 | controlnet_cond = torch.flip(controlnet_cond, dims=[1]) 112 | else: 113 | raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") 114 | 115 | # prepare attention_mask 116 | if attention_mask is not None: 117 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 118 | attention_mask = attention_mask.unsqueeze(1) 119 | 120 | # 1. time 121 | timesteps = timestep 122 | if not torch.is_tensor(timesteps): 123 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 124 | # This would be a good case for the `match` statement (Python 3.10+) 125 | is_mps = sample.device.type == "mps" 126 | if isinstance(timestep, float): 127 | dtype = torch.float32 if is_mps else torch.float64 128 | else: 129 | dtype = torch.int32 if is_mps else torch.int64 130 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 131 | elif len(timesteps.shape) == 0: 132 | timesteps = timesteps[None].to(sample.device) 133 | 134 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 135 | timesteps = timesteps.expand(sample.shape[0]) 136 | 137 | t_emb = self.time_proj(timesteps) 138 | 139 | # timesteps does not contain any weights and will always return f32 tensors 140 | # but time_embedding might actually be running in fp16. so we need to cast here. 141 | # there might be better ways to encapsulate this. 142 | t_emb = t_emb.to(dtype=sample.dtype) 143 | 144 | emb = self.time_embedding(t_emb, timestep_cond) 145 | aug_emb = None 146 | 147 | if self.class_embedding is not None: 148 | if class_labels is None: 149 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 150 | 151 | if self.config.class_embed_type == "timestep": 152 | class_labels = self.time_proj(class_labels) 153 | 154 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 155 | emb = emb + class_emb 156 | 157 | if self.config.addition_embed_type is not None: 158 | if self.config.addition_embed_type == "text": 159 | aug_emb = self.add_embedding(encoder_hidden_states) 160 | 161 | elif self.config.addition_embed_type == "text_time": 162 | if "text_embeds" not in added_cond_kwargs: 163 | raise ValueError( 164 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 165 | ) 166 | text_embeds = added_cond_kwargs.get("text_embeds") 167 | if "time_ids" not in added_cond_kwargs: 168 | raise ValueError( 169 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 170 | ) 171 | time_ids = added_cond_kwargs.get("time_ids") 172 | time_embeds = self.add_time_proj(time_ids.flatten()) 173 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 174 | 175 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 176 | add_embeds = add_embeds.to(emb.dtype) 177 | aug_emb = self.add_embedding(add_embeds) 178 | 179 | 180 | emb = emb + aug_emb if aug_emb is not None else emb 181 | # 2. pre-process 182 | sample = self.conv_in(sample) 183 | 184 | # 3. down 185 | down_block_res_samples = (sample,) 186 | for downsample_block in self.down_blocks: 187 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 188 | sample, res_samples = downsample_block( 189 | hidden_states=sample, 190 | temb=emb, 191 | encoder_hidden_states=encoder_hidden_states, 192 | attention_mask=attention_mask, 193 | cross_attention_kwargs=cross_attention_kwargs, 194 | ) 195 | else: 196 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 197 | 198 | down_block_res_samples += res_samples 199 | 200 | # 4. 
mid 201 | if self.mid_block is not None: 202 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 203 | sample = self.mid_block( 204 | sample, 205 | emb, 206 | encoder_hidden_states=encoder_hidden_states, 207 | attention_mask=attention_mask, 208 | cross_attention_kwargs=cross_attention_kwargs, 209 | ) 210 | else: 211 | sample = self.mid_block(sample, emb) 212 | 213 | # 5. Control net blocks 214 | 215 | controlnet_down_block_res_samples = () 216 | 217 | # NOTE that controlnet downblock is zeroconv, we discard 218 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 219 | down_block_res_sample = down_block_res_sample 220 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 221 | 222 | down_block_res_samples = controlnet_down_block_res_samples 223 | 224 | mid_block_res_sample = sample 225 | 226 | # 6. scaling 227 | if guess_mode and not self.config.global_pool_conditions: 228 | scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 229 | scales = scales * conditioning_scale 230 | down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] 231 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 232 | else: 233 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 234 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 235 | 236 | if self.config.global_pool_conditions: 237 | down_block_res_samples = [ 238 | torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples 239 | ] 240 | mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) 241 | 242 | if not return_dict: 243 | return (down_block_res_samples, mid_block_res_sample) 244 | 245 | return ControlNetOutput( 246 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 247 | ) 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /DAI/decoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from typing import Dict, Optional, Tuple, Union 5 | from diffusers import AutoencoderKL 6 | from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, Decoder 7 | from diffusers.models.attention_processor import Attention, AttentionProcessor 8 | from diffusers.models.modeling_outputs import AutoencoderKLOutput 9 | from diffusers.models.unets.unet_2d_blocks import ( 10 | AutoencoderTinyBlock, 11 | UNetMidBlock2D, 12 | get_down_block, 13 | get_up_block, 14 | ) 15 | from diffusers.utils.accelerate_utils import apply_forward_hook 16 | 17 | class ZeroConv2d(nn.Module): 18 | """ 19 | Zero Convolution layer, similar to the one used in ControlNet. 
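The 1x1 convolution is initialized with zero weights and bias, so each skip connection initially contributes nothing to the decoder output and its influence only grows as the weights are trained, mirroring the zero-convolution trick in ControlNet.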
20 | """ 21 | def __init__(self, in_channels, out_channels): 22 | super().__init__() 23 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 24 | self.conv.weight.data.zero_() 25 | self.conv.bias.data.zero_() 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | class CustomAutoencoderKL(AutoencoderKL): 31 | def __init__( 32 | self, 33 | in_channels: int = 3, 34 | out_channels: int = 3, 35 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 36 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 37 | block_out_channels: Tuple[int] = (64,), 38 | layers_per_block: int = 1, 39 | act_fn: str = "silu", 40 | latent_channels: int = 4, 41 | norm_num_groups: int = 32, 42 | sample_size: int = 32, 43 | scaling_factor: float = 0.18215, 44 | force_upcast: float = True, 45 | use_quant_conv: bool = True, 46 | use_post_quant_conv: bool = True, 47 | mid_block_add_attention: bool = True, 48 | ): 49 | super().__init__( 50 | in_channels=in_channels, 51 | out_channels=out_channels, 52 | down_block_types=down_block_types, 53 | up_block_types=up_block_types, 54 | block_out_channels=block_out_channels, 55 | layers_per_block=layers_per_block, 56 | act_fn=act_fn, 57 | latent_channels=latent_channels, 58 | norm_num_groups=norm_num_groups, 59 | sample_size=sample_size, 60 | scaling_factor=scaling_factor, 61 | force_upcast=force_upcast, 62 | use_quant_conv=use_quant_conv, 63 | use_post_quant_conv=use_post_quant_conv, 64 | mid_block_add_attention=mid_block_add_attention, 65 | ) 66 | 67 | # Add Zero Convolution layers to the encoder 68 | # self.zero_convs = nn.ModuleList() 69 | # for i, out_channels_ in enumerate(block_out_channels): 70 | # self.zero_convs.append(ZeroConv2d(out_channels_, out_channels_)) 71 | 72 | # Modify the decoder to accept skip connections 73 | self.decoder = CustomDecoder( 74 | in_channels=latent_channels, 75 | out_channels=out_channels, 76 | up_block_types=up_block_types, 77 | block_out_channels=block_out_channels, 78 | layers_per_block=layers_per_block, 79 | norm_num_groups=norm_num_groups, 80 | act_fn=act_fn, 81 | mid_block_add_attention=mid_block_add_attention, 82 | ) 83 | self.encoder = CustomEncoder( 84 | in_channels=in_channels, 85 | out_channels=latent_channels, 86 | down_block_types=down_block_types, 87 | block_out_channels=block_out_channels, 88 | layers_per_block=layers_per_block, 89 | norm_num_groups=norm_num_groups, 90 | act_fn=act_fn, 91 | mid_block_add_attention=mid_block_add_attention, 92 | ) 93 | 94 | def encode(self, x: torch.Tensor, return_dict: bool = True): 95 | # Get the encoder outputs 96 | _, skip_connections = self.encoder(x) 97 | 98 | return skip_connections 99 | 100 | def decode(self, z: torch.Tensor, skip_connections: list, return_dict: bool = True): 101 | if self.post_quant_conv is not None: 102 | z = self.post_quant_conv(z) 103 | # Decode the latent representation with skip connections 104 | dec = self.decoder(z, skip_connections) 105 | 106 | if not return_dict: 107 | return (dec,) 108 | 109 | return DecoderOutput(sample=dec) 110 | 111 | def forward( 112 | self, 113 | sample: torch.Tensor, 114 | sample_posterior: bool = False, 115 | return_dict: bool = True, 116 | generator: Optional[torch.Generator] = None, 117 | ): 118 | # Encode the input and get the skip connections 119 | posterior, skip_connections = self.encode(sample, return_dict=True) 120 | 121 | # Sample from the posterior 122 | if sample_posterior: 123 | z = posterior.sample(generator=generator) 124 | else: 125 | z = posterior.mode() 126 | 127 | # Decode 
the latent representation with skip connections 128 | dec = self.decode(z, skip_connections, return_dict=return_dict) 129 | 130 | if not return_dict: 131 | return (dec,) 132 | 133 | return DecoderOutput(sample=dec) 134 | 135 | 136 | class CustomDecoder(Decoder): 137 | def __init__( 138 | self, 139 | in_channels: int, 140 | out_channels: int, 141 | up_block_types: Tuple[str, ...], 142 | block_out_channels: Tuple[int, ...], 143 | layers_per_block: int, 144 | norm_num_groups: int, 145 | act_fn: str, 146 | mid_block_add_attention: bool, 147 | ): 148 | super().__init__( 149 | in_channels=in_channels, 150 | out_channels=out_channels, 151 | up_block_types=up_block_types, 152 | block_out_channels=block_out_channels, 153 | layers_per_block=layers_per_block, 154 | norm_num_groups=norm_num_groups, 155 | act_fn=act_fn, 156 | mid_block_add_attention=mid_block_add_attention, 157 | ) 158 | 159 | def forward( 160 | self, 161 | sample: torch.Tensor, 162 | skip_connections: list, 163 | latent_embeds: Optional[torch.Tensor] = None, 164 | ) -> torch.Tensor: 165 | r"""The forward method of the `Decoder` class.""" 166 | 167 | sample = self.conv_in(sample) 168 | 169 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 170 | if torch.is_grad_enabled() and self.gradient_checkpointing: 171 | 172 | def create_custom_forward(module): 173 | def custom_forward(*inputs): 174 | return module(*inputs) 175 | 176 | return custom_forward 177 | 178 | if is_torch_version(">=", "1.11.0"): 179 | # middle 180 | sample = torch.utils.checkpoint.checkpoint( 181 | create_custom_forward(self.mid_block), 182 | sample, 183 | latent_embeds, 184 | use_reentrant=False, 185 | ) 186 | sample = sample.to(upscale_dtype) 187 | 188 | # up 189 | for up_block in self.up_blocks: 190 | sample = torch.utils.checkpoint.checkpoint( 191 | create_custom_forward(up_block), 192 | sample, 193 | latent_embeds, 194 | use_reentrant=False, 195 | ) 196 | else: 197 | # middle 198 | sample = torch.utils.checkpoint.checkpoint( 199 | create_custom_forward(self.mid_block), sample, latent_embeds 200 | ) 201 | sample = sample.to(upscale_dtype) 202 | 203 | # up 204 | for up_block in self.up_blocks: 205 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) 206 | else: 207 | # middle 208 | sample = self.mid_block(sample, latent_embeds) 209 | sample = sample.to(upscale_dtype) 210 | 211 | # up 212 | # for up_block in self.up_blocks: 213 | # sample = up_block(sample, latent_embeds) 214 | for i, up_block in enumerate(self.up_blocks): 215 | # Add skip connections directly 216 | if i < len(skip_connections): 217 | skip_connection = skip_connections[-(i + 1)] 218 | # import pdb; pdb.set_trace() 219 | sample = sample + skip_connection 220 | # import pdb; pdb.set_trace() #torch.Size([1, 512, 96, 96] 221 | sample = up_block(sample) 222 | 223 | # post-process 224 | if latent_embeds is None: 225 | sample = self.conv_norm_out(sample) 226 | else: 227 | sample = self.conv_norm_out(sample, latent_embeds) 228 | sample = self.conv_act(sample) 229 | sample = self.conv_out(sample) 230 | 231 | return sample 232 | 233 | class CustomEncoder(Encoder): 234 | r""" 235 | Custom Encoder that adds Zero Convolution layers to each block's output 236 | to generate skip connections. 237 | """ 238 | def __init__( 239 | self, 240 | in_channels: int = 3, 241 | out_channels: int = 3, 242 | down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), 243 | block_out_channels: Tuple[int, ...] 
= (64,), 244 | layers_per_block: int = 2, 245 | norm_num_groups: int = 32, 246 | act_fn: str = "silu", 247 | double_z: bool = True, 248 | mid_block_add_attention: bool = True, 249 | ): 250 | super().__init__( 251 | in_channels=in_channels, 252 | out_channels=out_channels, 253 | down_block_types=down_block_types, 254 | block_out_channels=block_out_channels, 255 | layers_per_block=layers_per_block, 256 | norm_num_groups=norm_num_groups, 257 | act_fn=act_fn, 258 | double_z=double_z, 259 | mid_block_add_attention=mid_block_add_attention, 260 | ) 261 | 262 | # Add Zero Convolution layers to each block's output 263 | self.zero_convs = nn.ModuleList() 264 | for i, out_channels in enumerate(block_out_channels): 265 | if i < 2: 266 | self.zero_convs.append(ZeroConv2d(out_channels, out_channels * 2)) 267 | else: 268 | self.zero_convs.append(ZeroConv2d(out_channels, out_channels)) 269 | 270 | def forward(self, sample: torch.Tensor) -> list[torch.Tensor]: 271 | r""" 272 | Forward pass of the CustomEncoder. 273 | 274 | Args: 275 | sample (`torch.Tensor`): Input tensor. 276 | 277 | Returns: 278 | `Tuple[torch.Tensor, List[torch.Tensor]]`: 279 | - The final latent representation. 280 | - A list of skip connections from each block. 281 | """ 282 | skip_connections = [] 283 | 284 | # Initial convolution 285 | sample = self.conv_in(sample) 286 | 287 | # Down blocks 288 | for i, (down_block, zero_conv) in enumerate(zip(self.down_blocks, self.zero_convs)): 289 | # import pdb; pdb.set_trace() 290 | sample = down_block(sample) 291 | if i != len(self.down_blocks) - 1: 292 | sample_out = nn.functional.interpolate(zero_conv(sample), scale_factor=2, mode='bilinear', align_corners=False) 293 | else: 294 | sample_out = zero_conv(sample) 295 | skip_connections.append(sample_out) 296 | 297 | 298 | # import pdb; pdb.set_trace() 299 | # torch.Size([1, 128, 768, 768]) 300 | # torch.Size([1, 128, 384, 384]) 301 | # torch.Size([1, 256, 192, 192]) 302 | # torch.Size([1, 512, 96, 96]) 303 | # torch.Size([1, 512, 96, 96]) 304 | 305 | # # Middle block 306 | # sample = self.mid_block(sample) 307 | 308 | # # Post-process 309 | # sample = self.conv_norm_out(sample) 310 | # sample = self.conv_act(sample) 311 | # sample = self.conv_out(sample) 312 | 313 | return sample, skip_connections -------------------------------------------------------------------------------- /DAI/pipeline_all.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | 45 | from diffusers.utils.torch_utils import randn_tensor 46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 47 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 48 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 49 | 50 | from DAI.decoder import CustomAutoencoderKL 51 | 52 | import pdb 53 | 54 | 55 | 56 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 57 | 58 | 59 | EXAMPLE_DOC_STRING = """ 60 | Examples: 61 | ```py 62 | >>> import diffusers 63 | >>> import torch 64 | 65 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 66 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 67 | ... ).to("cuda") 68 | 69 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 70 | >>> normals = pipe(image) 71 | 72 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 73 | >>> vis[0].save("einstein_normals.png") 74 | ``` 75 | """ 76 | 77 | 78 | @dataclass 79 | class DAIOutput(BaseOutput): 80 | """ 81 | Output class for Marigold monocular normals prediction pipeline. 82 | 83 | Args: 84 | prediction (`np.ndarray`, `torch.Tensor`): 85 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 86 | \times width$, regardless of whether the images were passed as a 4D array or a list. 87 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 88 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 89 | \times 1 \times height \times width$. 90 | latent (`None`, `torch.Tensor`): 91 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 92 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 93 | """ 94 | 95 | prediction: Union[np.ndarray, torch.Tensor] 96 | latent: Union[None, torch.Tensor] 97 | gaus_noise: Union[None, torch.Tensor] 98 | 99 | 100 | class DAIPipeline(StableDiffusionControlNetPipeline): 101 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 102 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 103 | 104 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 105 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
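    In this repository the pipeline is used for single-step reflection removal (Dereflection Any Image): the encoded input image is fed to the ControlNet branch as conditioning, a single UNet step is run with a fixed prompt (default: "remove glass reflection"), and the resulting latent is decoded with the skip-connection decoder of the `CustomAutoencoderKL` passed as `vae_2` at call time.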
106 | 107 | The pipeline also inherits the following loading methods: 108 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 109 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 110 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 111 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 112 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 113 | 114 | Args: 115 | vae ([`AutoencoderKL`]): 116 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 117 | text_encoder ([`~transformers.CLIPTextModel`]): 118 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 119 | tokenizer ([`~transformers.CLIPTokenizer`]): 120 | A `CLIPTokenizer` to tokenize text. 121 | unet ([`UNet2DConditionModel`]): 122 | A `UNet2DConditionModel` to denoise the encoded image latents. 123 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 124 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 125 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 126 | additional conditioning. 127 | scheduler ([`SchedulerMixin`]): 128 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 129 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 130 | safety_checker ([`StableDiffusionSafetyChecker`]): 131 | Classification module that estimates whether generated images could be considered offensive or harmful. 132 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 133 | about a model's potential harms. 134 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 135 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
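        default_denoising_steps (`int`, *optional*, defaults to `1`):
            Number of denoising steps used when `num_inference_steps` is not given at call time.
        default_processing_resolution (`int`, *optional*, defaults to `768`):
            Processing resolution used when `processing_resolution` is not given at call time.
        prompt (`str`, defaults to `"remove glass reflection"`):
            Fixed text prompt; it is encoded once on the first call and the embedding is reused afterwards.
        t_start (`int`, *optional*, defaults to `0`):
            Timestep fed to the ControlNet and UNet for the single-step prediction.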
136 | """ 137 | 138 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 139 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 140 | _exclude_from_cpu_offload = ["safety_checker"] 141 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 142 | 143 | 144 | 145 | def __init__( 146 | self, 147 | # vae_2: CustomAutoencoderKL, 148 | vae: AutoencoderKL, 149 | text_encoder: CLIPTextModel, 150 | tokenizer: CLIPTokenizer, 151 | unet: UNet2DConditionModel, 152 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 153 | scheduler: Union[DDIMScheduler], 154 | safety_checker: StableDiffusionSafetyChecker, 155 | feature_extractor: CLIPImageProcessor, 156 | image_encoder: CLIPVisionModelWithProjection = None, 157 | requires_safety_checker: bool = True, 158 | default_denoising_steps: Optional[int] = 1, 159 | default_processing_resolution: Optional[int] = 768, 160 | prompt="remove glass reflection", 161 | empty_text_embedding=None, 162 | t_start: Optional[int] = 0, 163 | ): 164 | super().__init__( 165 | vae, 166 | text_encoder, 167 | tokenizer, 168 | unet, 169 | controlnet, 170 | scheduler, 171 | safety_checker, 172 | feature_extractor, 173 | image_encoder, 174 | requires_safety_checker, 175 | ) 176 | # self.vae_2 = vae_2 177 | 178 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 179 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 180 | self.default_denoising_steps = default_denoising_steps 181 | self.default_processing_resolution = default_processing_resolution 182 | self.prompt = prompt 183 | self.prompt_embeds = None 184 | self.empty_text_embedding = empty_text_embedding 185 | self.t_start= t_start # target_out latents 186 | 187 | 188 | def check_inputs( 189 | self, 190 | image: PipelineImageInput, 191 | num_inference_steps: int, 192 | ensemble_size: int, 193 | processing_resolution: int, 194 | resample_method_input: str, 195 | resample_method_output: str, 196 | batch_size: int, 197 | ensembling_kwargs: Optional[Dict[str, Any]], 198 | latents: Optional[torch.Tensor], 199 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 200 | output_type: str, 201 | output_uncertainty: bool, 202 | ) -> int: 203 | if num_inference_steps is None: 204 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 205 | if num_inference_steps < 1: 206 | raise ValueError("`num_inference_steps` must be positive.") 207 | if ensemble_size < 1: 208 | raise ValueError("`ensemble_size` must be positive.") 209 | if ensemble_size == 2: 210 | logger.warning( 211 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 212 | "consider increasing the value to at least 3." 213 | ) 214 | if ensemble_size == 1 and output_uncertainty: 215 | raise ValueError( 216 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 217 | "greater than 1." 218 | ) 219 | if processing_resolution is None: 220 | raise ValueError( 221 | "`processing_resolution` is not specified and could not be resolved from the model config." 222 | ) 223 | if processing_resolution < 0: 224 | raise ValueError( 225 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 226 | "downsampled processing." 
227 | ) 228 | if processing_resolution % self.vae_scale_factor != 0: 229 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 230 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 231 | raise ValueError( 232 | "`resample_method_input` takes string values compatible with PIL library: " 233 | "nearest, nearest-exact, bilinear, bicubic, area." 234 | ) 235 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 236 | raise ValueError( 237 | "`resample_method_output` takes string values compatible with PIL library: " 238 | "nearest, nearest-exact, bilinear, bicubic, area." 239 | ) 240 | if batch_size < 1: 241 | raise ValueError("`batch_size` must be positive.") 242 | if output_type not in ["pt", "np"]: 243 | raise ValueError("`output_type` must be one of `pt` or `np`.") 244 | if latents is not None and generator is not None: 245 | raise ValueError("`latents` and `generator` cannot be used together.") 246 | if ensembling_kwargs is not None: 247 | if not isinstance(ensembling_kwargs, dict): 248 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 249 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 250 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 251 | 252 | # image checks 253 | num_images = 0 254 | W, H = None, None 255 | if not isinstance(image, list): 256 | image = [image] 257 | for i, img in enumerate(image): 258 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 259 | if img.ndim not in (2, 3, 4): 260 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 261 | H_i, W_i = img.shape[-2:] 262 | N_i = 1 263 | if img.ndim == 4: 264 | N_i = img.shape[0] 265 | elif isinstance(img, Image.Image): 266 | W_i, H_i = img.size 267 | N_i = 1 268 | else: 269 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 270 | if W is None: 271 | W, H = W_i, H_i 272 | elif (W, H) != (W_i, H_i): 273 | raise ValueError( 274 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 275 | ) 276 | num_images += N_i 277 | 278 | # latents checks 279 | if latents is not None: 280 | if not torch.is_tensor(latents): 281 | raise ValueError("`latents` must be a torch.Tensor.") 282 | if latents.dim() != 4: 283 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 284 | 285 | if processing_resolution > 0: 286 | max_orig = max(H, W) 287 | new_H = H * processing_resolution // max_orig 288 | new_W = W * processing_resolution // max_orig 289 | if new_H == 0 or new_W == 0: 290 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 291 | W, H = new_W, new_H 292 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 293 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 294 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 295 | 296 | if latents.shape != shape_expected: 297 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 298 | 299 | # generator checks 300 | if generator is not None: 301 | if isinstance(generator, list): 302 | if len(generator) != num_images * ensemble_size: 303 | raise ValueError( 304 | "The number of generators must match the total number of ensemble members for all input images." 
305 | ) 306 | if not all(g.device.type == generator[0].device.type for g in generator): 307 | raise ValueError("`generator` device placement is not consistent in the list.") 308 | elif not isinstance(generator, torch.Generator): 309 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 310 | 311 | return num_images 312 | 313 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 314 | if not hasattr(self, "_progress_bar_config"): 315 | self._progress_bar_config = {} 316 | elif not isinstance(self._progress_bar_config, dict): 317 | raise ValueError( 318 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 319 | ) 320 | 321 | progress_bar_config = dict(**self._progress_bar_config) 322 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 323 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 324 | if iterable is not None: 325 | return tqdm(iterable, **progress_bar_config) 326 | elif total is not None: 327 | return tqdm(total=total, **progress_bar_config) 328 | else: 329 | raise ValueError("Either `total` or `iterable` has to be defined.") 330 | 331 | @torch.no_grad() 332 | @replace_example_docstring(EXAMPLE_DOC_STRING) 333 | def __call__( 334 | self, 335 | image: PipelineImageInput, 336 | vae_2: CustomAutoencoderKL, 337 | prompt: Union[str, List[str]] = None, 338 | negative_prompt: Optional[Union[str, List[str]]] = None, 339 | num_inference_steps: Optional[int] = None, 340 | ensemble_size: int = 1, 341 | processing_resolution: Optional[int] = None, 342 | match_input_resolution: bool = True, 343 | resample_method_input: str = "bilinear", 344 | resample_method_output: str = "bilinear", 345 | batch_size: int = 1, 346 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 347 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 348 | prompt_embeds: Optional[torch.Tensor] = None, 349 | negative_prompt_embeds: Optional[torch.Tensor] = None, 350 | num_images_per_prompt: Optional[int] = 1, 351 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 352 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 353 | output_type: str = "np", 354 | output_uncertainty: bool = False, 355 | output_latent: bool = False, 356 | skip_preprocess: bool = False, 357 | return_dict: bool = True, 358 | **kwargs, 359 | ): 360 | """ 361 | Function invoked when calling the pipeline. 362 | 363 | Args: 364 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 365 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 366 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 367 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 368 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 369 | same width and height. 370 | num_inference_steps (`int`, *optional*, defaults to `None`): 371 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 372 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 373 | for Marigold-LCM models. 374 | ensemble_size (`int`, defaults to `1`): 375 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 376 | faster inference. 
377 | processing_resolution (`int`, *optional*, defaults to `None`): 378 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 379 | produces crisper predictions, but may also lead to the overall loss of global context. The default 380 | value `None` resolves to the optimal value from the model config. 381 | match_input_resolution (`bool`, *optional*, defaults to `True`): 382 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 383 | side of the output will equal to `processing_resolution`. 384 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 385 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 386 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 387 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 388 | Resampling method used to resize output predictions to match the input resolution. The accepted values 389 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 390 | batch_size (`int`, *optional*, defaults to `1`): 391 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 392 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 393 | Extra dictionary with arguments for precise ensembling control. The following options are available: 394 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 395 | every pixel location, can be either `"closest"` or `"mean"`. 396 | latents (`torch.Tensor`, *optional*, defaults to `None`): 397 | Latent noise tensors to replace the random initialization. These can be taken from the previous 398 | function call's output. 399 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 400 | Random number generator object to ensure reproducibility. 401 | output_type (`str`, *optional*, defaults to `"np"`): 402 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 403 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 404 | output_uncertainty (`bool`, *optional*, defaults to `False`): 405 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 406 | the `ensemble_size` argument is set to a value above 2. 407 | output_latent (`bool`, *optional*, defaults to `False`): 408 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 409 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 410 | `latents` argument. 411 | return_dict (`bool`, *optional*, defaults to `True`): 412 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 413 | 414 | Examples: 415 | 416 | Returns: 417 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 418 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 419 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 420 | (or `None`), and the third is the latent (or `None`). 421 | """ 422 | 423 | # 0. Resolving variables. 424 | device = self._execution_device 425 | dtype = self.dtype 426 | 427 | # Model-specific optimal default values leading to fast and reasonable results. 
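        # Hedged usage sketch (assumed variable names `pipe`, `input_image`, `vae_2`):
        #     out = pipe(image=input_image, vae_2=vae_2)   # vae_2 is the CustomAutoencoderKL
        #     pred = out.prediction                        # [N,H,W,3] numpy array for output_type="np"
        # When not given, `num_inference_steps` falls back to `default_denoising_steps` (1)
        # and `processing_resolution` to `default_processing_resolution` (768) below.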
428 | if num_inference_steps is None: 429 | num_inference_steps = self.default_denoising_steps 430 | if processing_resolution is None: 431 | processing_resolution = self.default_processing_resolution 432 | 433 | 434 | # 1. Check inputs. 435 | num_images = self.check_inputs( 436 | image, 437 | num_inference_steps, 438 | ensemble_size, 439 | processing_resolution, 440 | resample_method_input, 441 | resample_method_output, 442 | batch_size, 443 | ensembling_kwargs, 444 | latents, 445 | generator, 446 | output_type, 447 | output_uncertainty, 448 | ) 449 | 450 | 451 | # 2. Prepare empty text conditioning. 452 | # Model invocation: self.tokenizer, self.text_encoder. 453 | if self.empty_text_embedding is None: 454 | prompt = "" 455 | text_inputs = self.tokenizer( 456 | prompt, 457 | padding="do_not_pad", 458 | max_length=self.tokenizer.model_max_length, 459 | truncation=True, 460 | return_tensors="pt", 461 | ) 462 | text_input_ids = text_inputs.input_ids.to(device) 463 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 464 | 465 | 466 | 467 | # 3. prepare prompt 468 | if self.prompt_embeds is None: 469 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 470 | self.prompt, 471 | device, 472 | num_images_per_prompt, 473 | False, 474 | negative_prompt, 475 | prompt_embeds=prompt_embeds, 476 | negative_prompt_embeds=None, 477 | lora_scale=None, 478 | clip_skip=None, 479 | ) 480 | self.prompt_embeds = prompt_embeds 481 | self.negative_prompt_embeds = negative_prompt_embeds 482 | 483 | 484 | 485 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 486 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 487 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 488 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 489 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 490 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 491 | # resolution can lead to loss of either fine details or global context in the output predictions. 492 | if not skip_preprocess: 493 | image, padding, original_resolution = self.image_processor.preprocess( 494 | image, processing_resolution, resample_method_input, device, dtype 495 | ) # [N,3,PPH,PPW] 496 | else: 497 | padding = (0, 0) 498 | original_resolution = image.shape[2:] 499 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 500 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 501 | # Latents of each such predictions across all input images and all ensemble members are represented in the 502 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 503 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 504 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 505 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 506 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 507 | # noise. 
This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space 508 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. 509 | # Model invocation: self.vae.encoder. 510 | image_latent, pred_latent = self.prepare_latents( 511 | image, latents, generator, ensemble_size, batch_size 512 | ) # [N*E,4,h,w], [N*E,4,h,w] 513 | 514 | gaus_noise = pred_latent.detach().clone() 515 | # del image 516 | 517 | 518 | # 6. obtain control_output 519 | 520 | cond_scale =controlnet_conditioning_scale 521 | down_block_res_samples, mid_block_res_sample = self.controlnet( 522 | image_latent.detach(), 523 | self.t_start, 524 | encoder_hidden_states=self.prompt_embeds, 525 | conditioning_scale=cond_scale, 526 | guess_mode=False, 527 | return_dict=False, 528 | ) 529 | 530 | # 7. Onestep sampling 531 | latent_x_t = self.unet( 532 | pred_latent, 533 | self.t_start, 534 | encoder_hidden_states=self.prompt_embeds, 535 | down_block_additional_residuals=down_block_res_samples, 536 | mid_block_additional_residual=mid_block_res_sample, 537 | return_dict=False, 538 | )[0] 539 | 540 | 541 | del ( 542 | pred_latent, 543 | image_latent, 544 | ) 545 | 546 | # encoder 547 | skip_connections = vae_2.encode(image) 548 | # decoder 549 | prediction = self.decode_prediction(latent_x_t, skip_connections, vae_2) 550 | 551 | 552 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 553 | 554 | prediction = self.image_processor.resize_antialias( 555 | prediction, original_resolution, resample_method_output, is_aa=False 556 | ) # [N,3,H,W] 557 | 558 | if output_type == "np": 559 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 560 | 561 | # 11. Offload all models 562 | self.maybe_free_model_hooks() 563 | 564 | return DAIOutput( 565 | prediction=prediction, 566 | latent=latent_x_t, 567 | gaus_noise=gaus_noise, 568 | ) 569 | 570 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 571 | def prepare_latents( 572 | self, 573 | image: torch.Tensor, 574 | latents: Optional[torch.Tensor], 575 | generator: Optional[torch.Generator], 576 | ensemble_size: int, 577 | batch_size: int, 578 | ) -> Tuple[torch.Tensor, torch.Tensor]: 579 | def retrieve_latents(encoder_output): 580 | if hasattr(encoder_output, "latent_dist"): 581 | return encoder_output.latent_dist.mode() 582 | elif hasattr(encoder_output, "latents"): 583 | return encoder_output.latents 584 | else: 585 | raise AttributeError("Could not access latents of provided encoder_output") 586 | 587 | 588 | 589 | image_latent = torch.cat( 590 | [ 591 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 592 | for i in range(0, image.shape[0], batch_size) 593 | ], 594 | dim=0, 595 | ) # [N,4,h,w] 596 | image_latent = image_latent * self.vae.config.scaling_factor 597 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 598 | 599 | pred_latent = torch.zeros_like(image_latent) 600 | if pred_latent is None: 601 | pred_latent = randn_tensor( 602 | image_latent.shape, 603 | generator=generator, 604 | device=image_latent.device, 605 | dtype=image_latent.dtype, 606 | ) # [N*E,4,h,w] 607 | 608 | return image_latent, pred_latent 609 | 610 | def decode_prediction(self, pred_latent: torch.Tensor, skip_connections: list, vae_2: CustomAutoencoderKL) -> torch.Tensor: 611 | if pred_latent.dim() != 4 or pred_latent.shape[1] != vae_2.config.latent_channels: 612 | raise ValueError( 613 | 
f"Expecting 4D tensor of shape [B,{vae_2.config.latent_channels},H,W]; got {pred_latent.shape}." 614 | ) 615 | 616 | prediction = vae_2.decode(pred_latent / vae_2.config.scaling_factor, skip_connections, return_dict=False)[0] # [B,3,H,W] 617 | 618 | return prediction # [B,3,H,W] 619 | 620 | @staticmethod 621 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 622 | if normals.dim() != 4 or normals.shape[1] != 3: 623 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 624 | 625 | norm = torch.norm(normals, dim=1, keepdim=True) 626 | normals /= norm.clamp(min=eps) 627 | 628 | return normals 629 | 630 | @staticmethod 631 | def ensemble_normals( 632 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" 633 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 634 | """ 635 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is 636 | the number of ensemble members for a given prediction of size `(H x W)`. 637 | 638 | Args: 639 | normals (`torch.Tensor`): 640 | Input ensemble normals maps. 641 | output_uncertainty (`bool`, *optional*, defaults to `False`): 642 | Whether to output uncertainty map. 643 | reduction (`str`, *optional*, defaults to `"closest"`): 644 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and 645 | `"mean"`. 646 | 647 | Returns: 648 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of 649 | uncertainties of shape `(1, 1, H, W)`. 650 | """ 651 | if normals.dim() != 4 or normals.shape[1] != 3: 652 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 653 | if reduction not in ("closest", "mean"): 654 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 655 | 656 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] 657 | mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] 658 | 659 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] 660 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 661 | 662 | uncertainty = None 663 | if output_uncertainty: 664 | uncertainty = sim_cos.arccos() # [E,1,H,W] 665 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] 666 | 667 | if reduction == "mean": 668 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] 669 | 670 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] 671 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] 672 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] 673 | 674 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] 675 | 676 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 677 | def retrieve_timesteps( 678 | scheduler, 679 | num_inference_steps: Optional[int] = None, 680 | device: Optional[Union[str, torch.device]] = None, 681 | timesteps: Optional[List[int]] = None, 682 | sigmas: Optional[List[float]] = None, 683 | **kwargs, 684 | ): 685 | """ 686 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 687 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 688 | 689 | Args: 690 | scheduler (`SchedulerMixin`): 691 | The scheduler to get timesteps from. 
692 | num_inference_steps (`int`): 693 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 694 | must be `None`. 695 | device (`str` or `torch.device`, *optional*): 696 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 697 | timesteps (`List[int]`, *optional*): 698 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 699 | `num_inference_steps` and `sigmas` must be `None`. 700 | sigmas (`List[float]`, *optional*): 701 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 702 | `num_inference_steps` and `timesteps` must be `None`. 703 | 704 | Returns: 705 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 706 | second element is the number of inference steps. 707 | """ 708 | if timesteps is not None and sigmas is not None: 709 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 710 | if timesteps is not None: 711 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 712 | if not accepts_timesteps: 713 | raise ValueError( 714 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 715 | f" timestep schedules. Please check whether you are using the correct scheduler." 716 | ) 717 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 718 | timesteps = scheduler.timesteps 719 | num_inference_steps = len(timesteps) 720 | elif sigmas is not None: 721 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 722 | if not accept_sigmas: 723 | raise ValueError( 724 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 725 | f" sigmas schedules. Please check whether you are using the correct scheduler." 726 | ) 727 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 728 | timesteps = scheduler.timesteps 729 | num_inference_steps = len(timesteps) 730 | else: 731 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 732 | timesteps = scheduler.timesteps 733 | return timesteps, num_inference_steps 734 | -------------------------------------------------------------------------------- /DAI/pipeline_onestep.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # -------------------------------------------------------------------------- 16 | # More information and citation instructions are available on the 17 | # -------------------------------------------------------------------------- 18 | from dataclasses import dataclass 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection 26 | 27 | 28 | from diffusers.image_processor import PipelineImageInput 29 | from diffusers.models import ( 30 | AutoencoderKL, 31 | UNet2DConditionModel, 32 | ControlNetModel, 33 | ) 34 | from diffusers.schedulers import ( 35 | DDIMScheduler 36 | ) 37 | 38 | from diffusers.utils import ( 39 | BaseOutput, 40 | logging, 41 | replace_example_docstring, 42 | ) 43 | 44 | 45 | from diffusers.utils.torch_utils import randn_tensor 46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline 47 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 48 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 49 | 50 | import pdb 51 | 52 | 53 | 54 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 55 | 56 | 57 | EXAMPLE_DOC_STRING = """ 58 | Examples: 59 | ```py 60 | >>> import diffusers 61 | >>> import torch 62 | 63 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( 64 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 65 | ... ).to("cuda") 66 | 67 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") 68 | >>> normals = pipe(image) 69 | 70 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction) 71 | >>> vis[0].save("einstein_normals.png") 72 | ``` 73 | """ 74 | 75 | 76 | @dataclass 77 | class DAIOutput(BaseOutput): 78 | """ 79 | Output class for Marigold monocular normals prediction pipeline. 80 | 81 | Args: 82 | prediction (`np.ndarray`, `torch.Tensor`): 83 | Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height 84 | \times width$, regardless of whether the images were passed as a 4D array or a list. 85 | uncertainty (`None`, `np.ndarray`, `torch.Tensor`): 86 | Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages 87 | \times 1 \times height \times width$. 88 | latent (`None`, `torch.Tensor`): 89 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 90 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 91 | """ 92 | 93 | prediction: Union[np.ndarray, torch.Tensor] 94 | latent: Union[None, torch.Tensor] 95 | gaus_noise: Union[None, torch.Tensor] 96 | 97 | 98 | class OneStepPipeline(StableDiffusionControlNetPipeline): 99 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. 100 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. 101 | 102 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 103 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
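    Unlike [`DAIPipeline`], this one-step variant does not import the skip-connection `CustomAutoencoderKL` decoder and defaults to `t_start=401`.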
104 | 105 | The pipeline also inherits the following loading methods: 106 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings 107 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights 108 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights 109 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files 110 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters 111 | 112 | Args: 113 | vae ([`AutoencoderKL`]): 114 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 115 | text_encoder ([`~transformers.CLIPTextModel`]): 116 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 117 | tokenizer ([`~transformers.CLIPTokenizer`]): 118 | A `CLIPTokenizer` to tokenize text. 119 | unet ([`UNet2DConditionModel`]): 120 | A `UNet2DConditionModel` to denoise the encoded image latents. 121 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): 122 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple 123 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined 124 | additional conditioning. 125 | scheduler ([`SchedulerMixin`]): 126 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 127 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 128 | safety_checker ([`StableDiffusionSafetyChecker`]): 129 | Classification module that estimates whether generated images could be considered offensive or harmful. 130 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details 131 | about a model's potential harms. 132 | feature_extractor ([`~transformers.CLIPImageProcessor`]): 133 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
134 | """ 135 | 136 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" 137 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] 138 | _exclude_from_cpu_offload = ["safety_checker"] 139 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 140 | 141 | 142 | 143 | def __init__( 144 | self, 145 | vae: AutoencoderKL, 146 | text_encoder: CLIPTextModel, 147 | tokenizer: CLIPTokenizer, 148 | unet: UNet2DConditionModel, 149 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]], 150 | scheduler: Union[DDIMScheduler], 151 | safety_checker: StableDiffusionSafetyChecker, 152 | feature_extractor: CLIPImageProcessor, 153 | image_encoder: CLIPVisionModelWithProjection = None, 154 | requires_safety_checker: bool = True, 155 | default_denoising_steps: Optional[int] = 1, 156 | default_processing_resolution: Optional[int] = 768, 157 | prompt="remove glass reflection", 158 | empty_text_embedding=None, 159 | t_start: Optional[int] = 401, 160 | ): 161 | super().__init__( 162 | vae, 163 | text_encoder, 164 | tokenizer, 165 | unet, 166 | controlnet, 167 | scheduler, 168 | safety_checker, 169 | feature_extractor, 170 | image_encoder, 171 | requires_safety_checker, 172 | ) 173 | 174 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 175 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 176 | self.default_denoising_steps = default_denoising_steps 177 | self.default_processing_resolution = default_processing_resolution 178 | self.prompt = prompt 179 | self.prompt_embeds = None 180 | self.empty_text_embedding = empty_text_embedding 181 | self.t_start= t_start # target_out latents 182 | 183 | def check_inputs( 184 | self, 185 | image: PipelineImageInput, 186 | num_inference_steps: int, 187 | ensemble_size: int, 188 | processing_resolution: int, 189 | resample_method_input: str, 190 | resample_method_output: str, 191 | batch_size: int, 192 | ensembling_kwargs: Optional[Dict[str, Any]], 193 | latents: Optional[torch.Tensor], 194 | generator: Optional[Union[torch.Generator, List[torch.Generator]]], 195 | output_type: str, 196 | output_uncertainty: bool, 197 | ) -> int: 198 | if num_inference_steps is None: 199 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") 200 | if num_inference_steps < 1: 201 | raise ValueError("`num_inference_steps` must be positive.") 202 | if ensemble_size < 1: 203 | raise ValueError("`ensemble_size` must be positive.") 204 | if ensemble_size == 2: 205 | logger.warning( 206 | "`ensemble_size` == 2 results are similar to no ensembling (1); " 207 | "consider increasing the value to at least 3." 208 | ) 209 | if ensemble_size == 1 and output_uncertainty: 210 | raise ValueError( 211 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " 212 | "greater than 1." 213 | ) 214 | if processing_resolution is None: 215 | raise ValueError( 216 | "`processing_resolution` is not specified and could not be resolved from the model config." 217 | ) 218 | if processing_resolution < 0: 219 | raise ValueError( 220 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 221 | "downsampled processing." 
222 | ) 223 | if processing_resolution % self.vae_scale_factor != 0: 224 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 225 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 226 | raise ValueError( 227 | "`resample_method_input` takes string values compatible with PIL library: " 228 | "nearest, nearest-exact, bilinear, bicubic, area." 229 | ) 230 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 231 | raise ValueError( 232 | "`resample_method_output` takes string values compatible with PIL library: " 233 | "nearest, nearest-exact, bilinear, bicubic, area." 234 | ) 235 | if batch_size < 1: 236 | raise ValueError("`batch_size` must be positive.") 237 | if output_type not in ["pt", "np"]: 238 | raise ValueError("`output_type` must be one of `pt` or `np`.") 239 | if latents is not None and generator is not None: 240 | raise ValueError("`latents` and `generator` cannot be used together.") 241 | if ensembling_kwargs is not None: 242 | if not isinstance(ensembling_kwargs, dict): 243 | raise ValueError("`ensembling_kwargs` must be a dictionary.") 244 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): 245 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") 246 | 247 | # image checks 248 | num_images = 0 249 | W, H = None, None 250 | if not isinstance(image, list): 251 | image = [image] 252 | for i, img in enumerate(image): 253 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 254 | if img.ndim not in (2, 3, 4): 255 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 256 | H_i, W_i = img.shape[-2:] 257 | N_i = 1 258 | if img.ndim == 4: 259 | N_i = img.shape[0] 260 | elif isinstance(img, Image.Image): 261 | W_i, H_i = img.size 262 | N_i = 1 263 | else: 264 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 265 | if W is None: 266 | W, H = W_i, H_i 267 | elif (W, H) != (W_i, H_i): 268 | raise ValueError( 269 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 270 | ) 271 | num_images += N_i 272 | 273 | # latents checks 274 | if latents is not None: 275 | if not torch.is_tensor(latents): 276 | raise ValueError("`latents` must be a torch.Tensor.") 277 | if latents.dim() != 4: 278 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") 279 | 280 | if processing_resolution > 0: 281 | max_orig = max(H, W) 282 | new_H = H * processing_resolution // max_orig 283 | new_W = W * processing_resolution // max_orig 284 | if new_H == 0 or new_W == 0: 285 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") 286 | W, H = new_W, new_H 287 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor 288 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor 289 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) 290 | 291 | if latents.shape != shape_expected: 292 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") 293 | 294 | # generator checks 295 | if generator is not None: 296 | if isinstance(generator, list): 297 | if len(generator) != num_images * ensemble_size: 298 | raise ValueError( 299 | "The number of generators must match the total number of ensemble members for all input images." 
300 | ) 301 | if not all(g.device.type == generator[0].device.type for g in generator): 302 | raise ValueError("`generator` device placement is not consistent in the list.") 303 | elif not isinstance(generator, torch.Generator): 304 | raise ValueError(f"Unsupported generator type: {type(generator)}.") 305 | 306 | return num_images 307 | 308 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 309 | if not hasattr(self, "_progress_bar_config"): 310 | self._progress_bar_config = {} 311 | elif not isinstance(self._progress_bar_config, dict): 312 | raise ValueError( 313 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 314 | ) 315 | 316 | progress_bar_config = dict(**self._progress_bar_config) 317 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 318 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 319 | if iterable is not None: 320 | return tqdm(iterable, **progress_bar_config) 321 | elif total is not None: 322 | return tqdm(total=total, **progress_bar_config) 323 | else: 324 | raise ValueError("Either `total` or `iterable` has to be defined.") 325 | 326 | @torch.no_grad() 327 | @replace_example_docstring(EXAMPLE_DOC_STRING) 328 | def __call__( 329 | self, 330 | image: PipelineImageInput, 331 | prompt: Union[str, List[str]] = None, 332 | negative_prompt: Optional[Union[str, List[str]]] = None, 333 | num_inference_steps: Optional[int] = None, 334 | ensemble_size: int = 1, 335 | processing_resolution: Optional[int] = None, 336 | match_input_resolution: bool = True, 337 | resample_method_input: str = "bilinear", 338 | resample_method_output: str = "bilinear", 339 | batch_size: int = 1, 340 | ensembling_kwargs: Optional[Dict[str, Any]] = None, 341 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, 342 | prompt_embeds: Optional[torch.Tensor] = None, 343 | negative_prompt_embeds: Optional[torch.Tensor] = None, 344 | num_images_per_prompt: Optional[int] = 1, 345 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 346 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 347 | output_type: str = "np", 348 | output_uncertainty: bool = False, 349 | output_latent: bool = False, 350 | skip_preprocess: bool = False, 351 | return_dict: bool = True, 352 | **kwargs, 353 | ): 354 | """ 355 | Function invoked when calling the pipeline. 356 | 357 | Args: 358 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 359 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For 360 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 361 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 362 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 363 | same width and height. 364 | num_inference_steps (`int`, *optional*, defaults to `None`): 365 | Number of denoising diffusion steps during inference. The default value `None` results in automatic 366 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 367 | for Marigold-LCM models. 368 | ensemble_size (`int`, defaults to `1`): 369 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for 370 | faster inference. 
371 | processing_resolution (`int`, *optional*, defaults to `None`): 372 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 373 | produces crisper predictions, but may also lead to the overall loss of global context. The default 374 | value `None` resolves to the optimal value from the model config. 375 | match_input_resolution (`bool`, *optional*, defaults to `True`): 376 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 377 | side of the output will equal to `processing_resolution`. 378 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 379 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 380 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 381 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 382 | Resampling method used to resize output predictions to match the input resolution. The accepted values 383 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 384 | batch_size (`int`, *optional*, defaults to `1`): 385 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images. 386 | ensembling_kwargs (`dict`, *optional*, defaults to `None`) 387 | Extra dictionary with arguments for precise ensembling control. The following options are available: 388 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in 389 | every pixel location, can be either `"closest"` or `"mean"`. 390 | latents (`torch.Tensor`, *optional*, defaults to `None`): 391 | Latent noise tensors to replace the random initialization. These can be taken from the previous 392 | function call's output. 393 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): 394 | Random number generator object to ensure reproducibility. 395 | output_type (`str`, *optional*, defaults to `"np"`): 396 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted 397 | values are: `"np"` (numpy array) or `"pt"` (torch tensor). 398 | output_uncertainty (`bool`, *optional*, defaults to `False`): 399 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that 400 | the `ensemble_size` argument is set to a value above 2. 401 | output_latent (`bool`, *optional*, defaults to `False`): 402 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 403 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 404 | `latents` argument. 405 | return_dict (`bool`, *optional*, defaults to `True`): 406 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. 407 | 408 | Examples: 409 | 410 | Returns: 411 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: 412 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a 413 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty 414 | (or `None`), and the third is the latent (or `None`). 415 | """ 416 | 417 | # 0. Resolving variables. 418 | device = self._execution_device 419 | dtype = self.dtype 420 | 421 | # Model-specific optimal default values leading to fast and reasonable results. 
422 | if num_inference_steps is None: 423 | num_inference_steps = self.default_denoising_steps 424 | if processing_resolution is None: 425 | processing_resolution = self.default_processing_resolution 426 | 427 | # 1. Check inputs. 428 | num_images = self.check_inputs( 429 | image, 430 | num_inference_steps, 431 | ensemble_size, 432 | processing_resolution, 433 | resample_method_input, 434 | resample_method_output, 435 | batch_size, 436 | ensembling_kwargs, 437 | latents, 438 | generator, 439 | output_type, 440 | output_uncertainty, 441 | ) 442 | 443 | 444 | # 2. Prepare empty text conditioning. 445 | # Model invocation: self.tokenizer, self.text_encoder. 446 | if self.empty_text_embedding is None: 447 | prompt = "" 448 | text_inputs = self.tokenizer( 449 | prompt, 450 | padding="do_not_pad", 451 | max_length=self.tokenizer.model_max_length, 452 | truncation=True, 453 | return_tensors="pt", 454 | ) 455 | text_input_ids = text_inputs.input_ids.to(device) 456 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 457 | 458 | 459 | 460 | # 3. prepare prompt 461 | if self.prompt_embeds is None: 462 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 463 | self.prompt, 464 | device, 465 | num_images_per_prompt, 466 | False, 467 | negative_prompt, 468 | prompt_embeds=prompt_embeds, 469 | negative_prompt_embeds=None, 470 | lora_scale=None, 471 | clip_skip=None, 472 | ) 473 | self.prompt_embeds = prompt_embeds 474 | self.negative_prompt_embeds = negative_prompt_embeds 475 | 476 | 477 | 478 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, 479 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 480 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 481 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 482 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 483 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 484 | # resolution can lead to loss of either fine details or global context in the output predictions. 485 | if not skip_preprocess: 486 | image, padding, original_resolution = self.image_processor.preprocess( 487 | image, processing_resolution, resample_method_input, device, dtype 488 | ) # [N,3,PPH,PPW] 489 | else: 490 | padding = (0, 0) 491 | original_resolution = image.shape[2:] 492 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 493 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 494 | # Latents of each such predictions across all input images and all ensemble members are represented in the 495 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 496 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure 497 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline 498 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken 499 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled 500 | # noise. 
This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space 501 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. 502 | # Model invocation: self.vae.encoder. 503 | image_latent, pred_latent = self.prepare_latents( 504 | image, latents, generator, ensemble_size, batch_size 505 | ) # [N*E,4,h,w], [N*E,4,h,w] 506 | 507 | gaus_noise = pred_latent.detach().clone() 508 | del image 509 | 510 | 511 | # 6. obtain control_output 512 | 513 | cond_scale =controlnet_conditioning_scale 514 | down_block_res_samples, mid_block_res_sample = self.controlnet( 515 | image_latent.detach(), 516 | self.t_start, 517 | encoder_hidden_states=self.prompt_embeds, 518 | conditioning_scale=cond_scale, 519 | guess_mode=False, 520 | return_dict=False, 521 | ) 522 | 523 | # 7. Onestep sampling 524 | latent_x_t = self.unet( 525 | pred_latent, 526 | self.t_start, 527 | encoder_hidden_states=self.prompt_embeds, 528 | down_block_additional_residuals=down_block_res_samples, 529 | mid_block_additional_residual=mid_block_res_sample, 530 | return_dict=False, 531 | )[0] 532 | 533 | 534 | del ( 535 | pred_latent, 536 | image_latent, 537 | ) 538 | 539 | # decoder 540 | prediction = self.decode_prediction(latent_x_t) 541 | 542 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] 543 | 544 | prediction = self.image_processor.resize_antialias( 545 | prediction, original_resolution, resample_method_output, is_aa=False 546 | ) # [N,3,H,W] 547 | 548 | if output_type == "np": 549 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] 550 | 551 | # 11. Offload all models 552 | self.maybe_free_model_hooks() 553 | 554 | return DAIOutput( 555 | prediction=prediction, 556 | latent=latent_x_t, 557 | gaus_noise=gaus_noise, 558 | ) 559 | 560 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents 561 | def prepare_latents( 562 | self, 563 | image: torch.Tensor, 564 | latents: Optional[torch.Tensor], 565 | generator: Optional[torch.Generator], 566 | ensemble_size: int, 567 | batch_size: int, 568 | ) -> Tuple[torch.Tensor, torch.Tensor]: 569 | def retrieve_latents(encoder_output): 570 | if hasattr(encoder_output, "latent_dist"): 571 | return encoder_output.latent_dist.mode() 572 | elif hasattr(encoder_output, "latents"): 573 | return encoder_output.latents 574 | else: 575 | raise AttributeError("Could not access latents of provided encoder_output") 576 | 577 | 578 | 579 | image_latent = torch.cat( 580 | [ 581 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 582 | for i in range(0, image.shape[0], batch_size) 583 | ], 584 | dim=0, 585 | ) # [N,4,h,w] 586 | image_latent = image_latent * self.vae.config.scaling_factor 587 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] 588 | 589 | pred_latent = torch.zeros_like(image_latent) 590 | if pred_latent is None: 591 | pred_latent = randn_tensor( 592 | image_latent.shape, 593 | generator=generator, 594 | device=image_latent.device, 595 | dtype=image_latent.dtype, 596 | ) # [N*E,4,h,w] 597 | 598 | return image_latent, pred_latent 599 | 600 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 601 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 602 | raise ValueError( 603 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 
604 | ) 605 | 606 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 607 | 608 | return prediction # [B,3,H,W] 609 | 610 | @staticmethod 611 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: 612 | if normals.dim() != 4 or normals.shape[1] != 3: 613 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 614 | 615 | norm = torch.norm(normals, dim=1, keepdim=True) 616 | normals /= norm.clamp(min=eps) 617 | 618 | return normals 619 | 620 | @staticmethod 621 | def ensemble_normals( 622 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" 623 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 624 | """ 625 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is 626 | the number of ensemble members for a given prediction of size `(H x W)`. 627 | 628 | Args: 629 | normals (`torch.Tensor`): 630 | Input ensemble normals maps. 631 | output_uncertainty (`bool`, *optional*, defaults to `False`): 632 | Whether to output uncertainty map. 633 | reduction (`str`, *optional*, defaults to `"closest"`): 634 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and 635 | `"mean"`. 636 | 637 | Returns: 638 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of 639 | uncertainties of shape `(1, 1, H, W)`. 640 | """ 641 | if normals.dim() != 4 or normals.shape[1] != 3: 642 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") 643 | if reduction not in ("closest", "mean"): 644 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 645 | 646 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] 647 | mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] 648 | 649 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] 650 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 651 | 652 | uncertainty = None 653 | if output_uncertainty: 654 | uncertainty = sim_cos.arccos() # [E,1,H,W] 655 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] 656 | 657 | if reduction == "mean": 658 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] 659 | 660 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] 661 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] 662 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] 663 | 664 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] 665 | 666 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 667 | def retrieve_timesteps( 668 | scheduler, 669 | num_inference_steps: Optional[int] = None, 670 | device: Optional[Union[str, torch.device]] = None, 671 | timesteps: Optional[List[int]] = None, 672 | sigmas: Optional[List[float]] = None, 673 | **kwargs, 674 | ): 675 | """ 676 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 677 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 678 | 679 | Args: 680 | scheduler (`SchedulerMixin`): 681 | The scheduler to get timesteps from. 682 | num_inference_steps (`int`): 683 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 684 | must be `None`. 
685 | device (`str` or `torch.device`, *optional*): 686 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 687 | timesteps (`List[int]`, *optional*): 688 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 689 | `num_inference_steps` and `sigmas` must be `None`. 690 | sigmas (`List[float]`, *optional*): 691 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 692 | `num_inference_steps` and `timesteps` must be `None`. 693 | 694 | Returns: 695 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 696 | second element is the number of inference steps. 697 | """ 698 | if timesteps is not None and sigmas is not None: 699 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 700 | if timesteps is not None: 701 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 702 | if not accepts_timesteps: 703 | raise ValueError( 704 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 705 | f" timestep schedules. Please check whether you are using the correct scheduler." 706 | ) 707 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 708 | timesteps = scheduler.timesteps 709 | num_inference_steps = len(timesteps) 710 | elif sigmas is not None: 711 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 712 | if not accept_sigmas: 713 | raise ValueError( 714 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 715 | f" sigmas schedules. Please check whether you are using the correct scheduler." 716 | ) 717 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 718 | timesteps = scheduler.timesteps 719 | num_inference_steps = len(timesteps) 720 | else: 721 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 722 | timesteps = scheduler.timesteps 723 | return timesteps, num_inference_steps 724 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # Dereflection Any Image with Diffusion Priors and Diversified Data 6 | ### [Project Page](https://abuuu122.github.io/DAI.github.io/) | [Paper](https://arxiv.org/abs/2503.17347) | [Data (coming soon)]() | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/129uKcCNfoR2sIn5RifqhYGpB0xa2tdsH?usp=sharing) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/spaces/sjtu-deepvision/Dereflection-Any-Image) | [![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Model-green)](https://huggingface.co/sjtu-deepvision/dereflection-any-image-v0) 7 | 8 | 📖[Dereflection Any Image with Diffusion Priors and Diversified Data](https://abuuu122.github.io/DAI.github.io/) 9 | 10 | [Jichen Hu](https://abuuu122.github.io/DAI.github.io/)1*, [Chen Yang](https://scholar.google.com/citations?hl=zh-CN&user=StdXTR8AAAAJ)1*, [Zanwei Zhou](https://abuuu122.github.io/DAI.github.io/)1, [Jiemin Fang](https://jaminfong.cn/)2†, [Xiaokang Yang](https://abuuu122.github.io/DAI.github.io/)1, [Qi Tian](https://www.qitian1987.com/)2, [Wei Shen](https://shenwei1231.github.io/)1✉†, 11 | 1MoE Key Lab of Artificial Intelligence, AI Institute, SJTU   2Huawei Inc.   12 | *Equal contribution.   Project lead.   Corresponding author. 13 | 14 |
15 | 16 |
17 | 18 | ## 📝 Todo 19 | 20 | - [x] Release inference code 21 | - [x] Release pretrained model weights 22 | - [x] Release project page 23 | - [x] Release paper 24 | - [ ] Release dataset 25 | - [ ] Release training code 26 | 27 | ## 🚀 Setup 28 | 29 | ### Environment 30 | Dereflection Any Image is tested with CUDA 11.8 and Python 3.9. All the required packages are listed in `requirements.txt`. You can install them with: 31 | 32 | ```sh 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 37 | ### Weights 38 | Our scripts will automatically download the pretrained weights from Huggingface. 39 | 40 | You can also download the weights from [Google Drive](https://drive.google.com/drive/folders/1WFczJ0LgVbFfVQym7FLGW-f5iU7G1Rr-?usp=drive_link) or [Huggingface](https://huggingface.co/JichenHu/dereflection-any-image-v0). 41 | 42 | ## 💪 Usage 43 | 44 | ### Inference 45 | Put your images in the `input` directory and run: 46 | ```sh 47 | python run.py --input_dir ./input/ --result_dir ./result/ --concat_dir ./concat/ 48 | ``` 49 | or use the `run.sh` script directly. 50 | 51 | ### Gradio Demo 52 | ```sh 53 | python demo.py 54 | ``` 55 | 56 | ## 🌏 Citation 57 | 58 | If you find Dereflection Any Image useful for your work, please cite: 59 | 60 | ```text 61 | @misc{hu2025dereflection, 62 | title={Dereflection Any Image with Diffusion Priors and Diversified Data}, 63 | author={Jichen Hu and Chen Yang and Zanwei Zhou and Jiemin Fang and Xiaokang Yang and Qi Tian and Wei Shen}, 64 | year={2025}, 65 | eprint={2503.17347}, 66 | archivePrefix={arXiv}, 67 | primaryClass={cs.CV} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/assets/logo.png -------------------------------------------------------------------------------- /assets/logo_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/assets/logo_old.png -------------------------------------------------------------------------------- /assets/teaser.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/assets/teaser.mp4 -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/assets/teaser.png -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | from __future__ import annotations 20 | 21 | import functools 22 | import os 23 | import tempfile 24 | 25 | import gradio as gr 26 | import imageio as imageio 27 | import numpy as np 28 | import spaces 29 | import torch as torch 30 | torch.backends.cuda.matmul.allow_tf32 = True 31 | from PIL import Image 32 | from gradio_imageslider import ImageSlider 33 | 34 | from pathlib import Path 35 | import gradio 36 | from gradio.utils import get_cache_folder 37 | 38 | from DAI.pipeline_all import DAIPipeline 39 | 40 | from DAI.controlnetvae import ControlNetVAEModel 41 | 42 | from DAI.decoder import CustomAutoencoderKL 43 | 44 | from diffusers import ( 45 | AutoencoderKL, 46 | UNet2DConditionModel, 47 | ) 48 | 49 | from transformers import CLIPTextModel, AutoTokenizer 50 | 51 | 52 | class Examples(gradio.helpers.Examples): 53 | def __init__(self, *args, directory_name=None, **kwargs): 54 | super().__init__(*args, **kwargs, _initiated_directly=False) 55 | if directory_name is not None: 56 | self.cached_folder = get_cache_folder() / directory_name 57 | self.cached_file = Path(self.cached_folder) / "log.csv" 58 | self.create() 59 | 60 | 61 | def process_image_check(path_input): 62 | if path_input is None: 63 | raise gr.Error( 64 | "Missing image in the first pane: upload a file or use one from the gallery below." 
65 | ) 66 | 67 | def process_image( 68 | pipe, 69 | vae_2, 70 | path_input, 71 | ): 72 | name_base, name_ext = os.path.splitext(os.path.basename(path_input)) 73 | print(f"Processing image {name_base}{name_ext}") 74 | 75 | path_output_dir = tempfile.mkdtemp() 76 | path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png") 77 | input_image = Image.open(path_input) 78 | resolution = None 79 | 80 | pipe_out = pipe( 81 | image=input_image, 82 | prompt="remove glass reflection", 83 | vae_2=vae_2, 84 | processing_resolution=resolution, 85 | ) 86 | 87 | processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2 88 | processed_frame = (processed_frame[0] * 255).astype(np.uint8) 89 | processed_frame = Image.fromarray(processed_frame) 90 | processed_frame.save(path_out_png) 91 | yield [input_image, path_out_png] 92 | 93 | 94 | def run_demo_server(pipe, vae_2): 95 | process_pipe_image = spaces.GPU(functools.partial(process_image, pipe, vae_2)) 96 | 97 | gradio_theme = gr.themes.Default() 98 | 99 | with gr.Blocks( 100 | theme=gradio_theme, 101 | title="DAI", 102 | css=""" 103 | #download { 104 | height: 118px; 105 | } 106 | .slider .inner { 107 | width: 5px; 108 | background: #FFF; 109 | } 110 | .viewport { 111 | aspect-ratio: 4/3; 112 | } 113 | .tabs button.selected { 114 | font-size: 20px !important; 115 | color: crimson !important; 116 | } 117 | h1 { 118 | text-align: center; 119 | display: block; 120 | } 121 | h2 { 122 | text-align: center; 123 | display: block; 124 | } 125 | h3 { 126 | text-align: center; 127 | display: block; 128 | } 129 | .md_feedback li { 130 | margin-bottom: 0px !important; 131 | } 132 | """, 133 | head=""" 134 | 135 | 141 | """, 142 | ) as demo: 143 | gr.Markdown( 144 | """ 145 | # Dereflection Any Image 146 |

147 | """ 148 | ) 149 | 150 | with gr.Tabs(elem_classes=["tabs"]): 151 | with gr.Tab("Image"): 152 | with gr.Row(): 153 | with gr.Column(): 154 | image_input = gr.Image( 155 | label="Input Image", 156 | type="filepath", 157 | ) 158 | with gr.Row(): 159 | image_submit_btn = gr.Button( 160 | value="Dereflection", variant="primary" 161 | ) 162 | image_reset_btn = gr.Button(value="Reset") 163 | with gr.Column(): 164 | image_output_slider = ImageSlider( 165 | label="outputs", 166 | type="filepath", 167 | show_download_button=True, 168 | show_share_button=True, 169 | interactive=False, 170 | elem_classes="slider", 171 | # position=0.25, 172 | ) 173 | 174 | Examples( 175 | fn=process_pipe_image, 176 | examples=sorted([ 177 | os.path.join("files", "image", name) 178 | for name in os.listdir(os.path.join("files", "image")) 179 | ]), 180 | inputs=[image_input], 181 | outputs=[image_output_slider], 182 | cache_examples=False, 183 | directory_name="examples_image", 184 | ) 185 | 186 | ### Image tab 187 | image_submit_btn.click( 188 | fn=process_image_check, 189 | inputs=image_input, 190 | outputs=None, 191 | preprocess=False, 192 | queue=False, 193 | ).success( 194 | fn=process_pipe_image, 195 | inputs=[ 196 | image_input, 197 | ], 198 | outputs=[image_output_slider], 199 | concurrency_limit=1, 200 | ) 201 | 202 | image_reset_btn.click( 203 | fn=lambda: ( 204 | None, 205 | None, 206 | None, 207 | ), 208 | inputs=[], 209 | outputs=[ 210 | image_input, 211 | image_output_slider, 212 | ], 213 | queue=False, 214 | ) 215 | 216 | ### Server launch 217 | 218 | demo.queue( 219 | api_open=False, 220 | ).launch( 221 | server_name="0.0.0.0", 222 | server_port=7860, 223 | ) 224 | 225 | 226 | def main(): 227 | os.system("pip freeze") 228 | 229 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 230 | 231 | weight_dtype = torch.float32 232 | pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0" 233 | pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1" 234 | revision = None 235 | variant = None 236 | 237 | # Load the model 238 | controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device) 239 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device) 240 | vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device) 241 | 242 | vae = AutoencoderKL.from_pretrained( 243 | pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant 244 | ).to(device) 245 | 246 | text_encoder = CLIPTextModel.from_pretrained( 247 | pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant 248 | ).to(device) 249 | tokenizer = AutoTokenizer.from_pretrained( 250 | pretrained_model_name_or_path2, 251 | subfolder="tokenizer", 252 | revision=revision, 253 | use_fast=False, 254 | ) 255 | pipe = DAIPipeline( 256 | vae=vae, 257 | text_encoder=text_encoder, 258 | tokenizer=tokenizer, 259 | unet=unet, 260 | controlnet=controlnet, 261 | safety_checker=None, 262 | scheduler=None, 263 | feature_extractor=None, 264 | t_start=0, 265 | ).to(device) 266 | 267 | try: 268 | import xformers 269 | pipe.enable_xformers_memory_efficient_attention() 270 | except: 271 | pass # run without xformers 272 | 273 | run_demo_server(pipe, vae_2) 274 | 275 | 276 | if __name__ == "__main__": 277 | main() 278 | 
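For reference, the pipeline assembled in `main()` of `demo.py` can also be called without the Gradio UI. The snippet below is a minimal sketch that mirrors `process_image` above; it assumes `pipe` and `vae_2` have been constructed exactly as in `main()`, and the input/output paths are placeholders, not files shipped with the demo.

```python
# Minimal sketch (not part of the original demo.py): single-image inference,
# assuming `pipe` and `vae_2` were built exactly as in main() above.
import numpy as np
from PIL import Image

input_image = Image.open("input/1.png")  # placeholder path
pipe_out = pipe(
    image=input_image,
    prompt="remove glass reflection",
    vae_2=vae_2,
    processing_resolution=None,  # None resolves to the model's default resolution
)
# Predictions are in [-1, 1]; map them to [0, 255] exactly as process_image() does.
result = ((pipe_out.prediction[0].clip(-1, 1) + 1) / 2 * 255).astype(np.uint8)
Image.fromarray(result).save("dereflected.png")  # placeholder output path
```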
-------------------------------------------------------------------------------- /files/image/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/1.png -------------------------------------------------------------------------------- /files/image/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/2.png -------------------------------------------------------------------------------- /files/image/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/3.png -------------------------------------------------------------------------------- /files/image/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/4.png -------------------------------------------------------------------------------- /files/image/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/5.png -------------------------------------------------------------------------------- /files/image/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/6.png -------------------------------------------------------------------------------- /files/image/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/7.png -------------------------------------------------------------------------------- /files/image/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/files/image/8.png -------------------------------------------------------------------------------- /input/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/1.png -------------------------------------------------------------------------------- /input/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/2.png -------------------------------------------------------------------------------- /input/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/3.png -------------------------------------------------------------------------------- /input/4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/4.png -------------------------------------------------------------------------------- /input/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/5.png -------------------------------------------------------------------------------- /input/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/6.png -------------------------------------------------------------------------------- /input/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/7.png -------------------------------------------------------------------------------- /input/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abuuu122/Dereflection-Any-Image/3c350ab3d4e867df47539a3903fc64db44c44ade/input/8.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | gradio 3 | gradio_imageslider 4 | torch 5 | transformers 6 | pillow 7 | numpy 8 | xformers 9 | spaces 10 | accelerate 11 | imageio -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from DAI.pipeline_onestep import OneStepPipeline 5 | from DAI.controlnetvae import ControlNetVAEModel 6 | import numpy as np 7 | from diffusers import ( 8 | AutoencoderKL, 9 | ControlNetModel, 10 | DDPMScheduler, 11 | StableDiffusionControlNetPipeline, 12 | UNet2DConditionModel, 13 | UniPCMultistepScheduler, 14 | StableDiffusionPipeline 15 | ) 16 | from transformers import CLIPTextModel, AutoTokenizer 17 | from glob import glob 18 | import json 19 | import random 20 | from diffusers.utils import make_image_grid, load_image 21 | from peft import PeftModel 22 | from peft import LoraConfig, get_peft_model 23 | from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict 24 | 25 | from safetensors.torch import load_file 26 | 27 | 28 | from DAI.pipeline_all import DAIPipeline 29 | from DAI.decoder import CustomAutoencoderKL 30 | 31 | from tqdm import tqdm 32 | import argparse 33 | 34 | 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | 37 | weight_dtype = torch.float32 38 | model_dir = "./weights" 39 | pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0" 40 | pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1" 41 | revision = None 42 | variant = None 43 | # Load the model 44 | # normal 45 | controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device) 46 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device) 47 | vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device) 
48 | 49 | 50 | # Load other components of the pipeline 51 | vae = AutoencoderKL.from_pretrained( 52 | pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant 53 | ).to(device) 54 | 55 | # import pdb; pdb.set_trace() 56 | text_encoder = CLIPTextModel.from_pretrained( 57 | pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant 58 | ).to(device) 59 | tokenizer = AutoTokenizer.from_pretrained( 60 | pretrained_model_name_or_path2, 61 | subfolder="tokenizer", 62 | revision=revision, 63 | use_fast=False, 64 | ) 65 | pipeline = DAIPipeline( 66 | vae=vae, 67 | text_encoder=text_encoder, 68 | tokenizer=tokenizer, 69 | unet=unet, 70 | controlnet=controlnet, 71 | safety_checker=None, 72 | scheduler=None, 73 | feature_extractor=None, 74 | t_start=0 75 | ).to(device) 76 | 77 | 78 | # Create a directory to save the results 79 | # Parse command line arguments 80 | parser = argparse.ArgumentParser(description="Run reflection removal on images.") 81 | parser.add_argument("--input_dir", type=str, required=True, help="Directory for evaluation inputs.") 82 | parser.add_argument("--result_dir", type=str, required=True, help="Directory for evaluation results.") 83 | parser.add_argument("--concat_dir", type=str, required=True, help="Directory for concat evaluation results.") 84 | 85 | args = parser.parse_args() 86 | 87 | input_dir = args.input_dir 88 | result_dir = args.result_dir 89 | concat_dir = args.concat_dir 90 | 91 | os.makedirs(result_dir, exist_ok=True) 92 | os.makedirs(concat_dir, exist_ok=True) 93 | 94 | input_files = sorted(glob(os.path.join(input_dir, "*"))) 95 | 96 | for input_file in tqdm(input_files, desc="Processing images"): 97 | input_image = load_image(input_file) 98 | 99 | resolution = 0 100 | if max(input_image.size) < 768: 101 | resolution = None 102 | result_image = pipeline( 103 | image=torch.tensor(np.array(input_image)).permute(2, 0, 1).float().div(255).unsqueeze(0).to(device), 104 | prompt="remove glass reflection", 105 | vae_2=vae_2, 106 | processing_resolution=resolution 107 | ).prediction[0] 108 | 109 | result_image = (result_image + 1) / 2 110 | result_image = result_image.clip(0., 1.) 111 | result_image = result_image * 255 112 | result_image = result_image.astype(np.uint8) 113 | result_image = Image.fromarray(result_image) 114 | 115 | concat_image = make_image_grid([input_image, result_image], rows=1, cols=2) 116 | 117 | # Save the concatenated image 118 | input_filename = os.path.basename(input_file) 119 | concat_image.save(os.path.join(concat_dir, f"{input_filename}")) 120 | result_image.save(os.path.join(result_dir, f"{input_filename}")) 121 | 122 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python run.py --input_dir ./input/ --result_dir ./result/ --concat_dir ./concat/ -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | 21 | # torchmetrics 22 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | --------------------------------------------------------------------------------
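The training code is not released yet (see the Todo list in the README), so these helpers never appear in use within the repository. The sketch below only illustrates how `l1_loss` and `ssim` from `utils/loss_utils.py` could be combined into a reconstruction objective; the weighting and the pairing itself are assumptions for illustration, not the authors' training recipe.

```python
# Illustrative only: a common L1 + (1 - SSIM) mix built from the helpers above.
# The lambda_ssim weight and the overall formulation are assumptions, not from the paper.
import torch
from utils.loss_utils import l1_loss, ssim

def reconstruction_loss(pred: torch.Tensor, gt: torch.Tensor, lambda_ssim: float = 0.2) -> torch.Tensor:
    # pred, gt: [B, 3, H, W] tensors with values in [0, 1]
    return (1.0 - lambda_ssim) * l1_loss(pred, gt) + lambda_ssim * (1.0 - ssim(pred, gt))
```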