├── example
│   ├── cat.png
│   ├── flower.png
│   └── metal.png
├── README.md
└── inversion_editing_cli.py

/example/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CyberDragon93/rf-inversion-diffuser/HEAD/example/cat.png
--------------------------------------------------------------------------------
/example/flower.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CyberDragon93/rf-inversion-diffuser/HEAD/example/flower.png
--------------------------------------------------------------------------------
/example/metal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CyberDragon93/rf-inversion-diffuser/HEAD/example/metal.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# rf-inversion-diffuser

Diffusers implementation of https://rf-inversion.github.io

Update: we suggest referring to this [link](https://github.com/lqiang67/rectified-flow/blob/main/examples/editing_flux_dev.ipynb) for an easier implementation.
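## Example

A minimal programmatic sketch of the two-stage workflow in `inversion_editing_cli.py` (inversion with a null prompt, then denoising toward an editing prompt). This mirrors what `main()` does; the model id, image path, and output filename below are illustrative placeholders, and the CLI exposes the same knobs via flags such as `--image_path`, `--prompt`, `--eta_base`, `--gamma`, and `--use_inversed_latents`.

```python
import torch
from PIL import Image
from torchvision import transforms
from diffusers import FluxPipeline

from inversion_editing_cli import (
    encode_imgs, decode_imgs, interpolated_inversion, interpolated_denoise,
)

DTYPE = torch.bfloat16
# Example checkpoint; the CLI default expects a local Flux-dev path instead.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=DTYPE).to("cuda")

# Normalize the source image to the 1024x1024 resolution the script assumes.
tf = transforms.Compose([
    transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(1024),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
img = tf(Image.open("example/cat.png").convert("RGB")).unsqueeze(0).to("cuda", DTYPE)

# Stage 1: invert the image toward noise with the interpolated (null-text) velocity field.
img_latent = encode_imgs(img, pipe, DTYPE)
inverted = interpolated_inversion(pipe, img_latent, gamma=0.5, DTYPE=DTYPE, num_steps=28)

# Stage 2: denoise from the inverted latent while steering toward the editing prompt.
edited = interpolated_denoise(
    pipe, img_latent,
    eta_base=1.0, eta_trend="linear_decrease", start_step=0, end_step=7,
    inversed_latents=inverted, use_inversed_latents=True,
    guidance_scale=3.5, prompt="photo of a tiger", DTYPE=DTYPE, num_steps=28,
)
decode_imgs(edited, pipe)[0].save("edited.png")
```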
--------------------------------------------------------------------------------
/inversion_editing_cli.py:
--------------------------------------------------------------------------------
"""Diffusers-based RF-Inversion for Flux: invert an image into noise with an
interpolated (null-text) velocity field, then denoise it back while steering
toward an editing prompt (https://rf-inversion.github.io)."""

import torch
import math
import argparse
import os

from PIL import Image
from diffusers import FluxPipeline
from torch import Tensor
from torchvision import transforms

@torch.inference_mode()
def decode_imgs(latents, pipeline):
    imgs = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
    imgs = pipeline.vae.decode(imgs)[0]
    imgs = pipeline.image_processor.postprocess(imgs, output_type="pil")
    return imgs

@torch.inference_mode()
def encode_imgs(imgs, pipeline, DTYPE):
    latents = pipeline.vae.encode(imgs).latent_dist.sample()
    latents = (latents - pipeline.vae.config.shift_factor) * pipeline.vae.config.scaling_factor
    latents = latents.to(dtype=DTYPE)
    return latents

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list:
    def get_lin_function(
        x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
    ):
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    def time_shift(mu: float, sigma: float, t: Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1, dtype=torch.float32)

    # shift the schedule to favor high timesteps for higher-signal (longer-sequence) images
    if shift:
        # estimate mu by linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

@torch.inference_mode()
def interpolated_inversion(
    pipeline,
    latents,
    gamma,
    DTYPE,
    num_steps=28,
    use_shift_t_sampling=False,
):
    timesteps = get_schedule(
        num_steps=num_steps,
        image_seq_len=(1024 // 16) * (1024 // 16),  # vae_scale_factor = 16
        shift=use_shift_t_sampling,  # set True for Flux-dev, False for Flux-schnell
    )[::-1]  # flipped for inversion
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(  # null text
        prompt="",
        prompt_2=""
    )
    latent_image_ids = pipeline._prepare_latent_image_ids(
        latents.shape[0],
        latents.shape[2],
        latents.shape[3],
        latents.device,
        DTYPE,
    )
    packed_latents = pipeline._pack_latents(
        latents,
        batch_size=latents.shape[0],
        num_channels_latents=latents.shape[1],
        height=latents.shape[2],
        width=latents.shape[3],
    )

    target_noise = torch.randn(packed_latents.shape, device=packed_latents.device, dtype=torch.float32)
    guidance_scale = 0.0  # zero guidance for inversion
    guidance_vec = torch.full((packed_latents.shape[0],), guidance_scale, device=packed_latents.device, dtype=packed_latents.dtype)

    # Image inversion with interpolated velocity field. t goes from 0.0 to 1.0
    with pipeline.progress_bar(total=len(timesteps) - 1) as progress_bar:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, device=packed_latents.device)

            # Null-text velocity
            flux_velocity = pipeline.transformer(
                hidden_states=packed_latents,
                timestep=t_vec,
                guidance=guidance_vec,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=None,
                return_dict=False,
            )[0]

            # Prevents precision issues
            packed_latents = packed_latents.to(torch.float32)
            flux_velocity = flux_velocity.to(torch.float32)

            # Target noise velocity
            target_noise_velocity = (target_noise - packed_latents) / (1.0 - t_curr)

            # Interpolated velocity
            interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * flux_velocity

            # One-step Euler update
            packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity

            packed_latents = packed_latents.to(DTYPE)
            progress_bar.update()

    print("Mean Absolute Error", torch.mean(torch.abs(packed_latents - target_noise)))

    latents = pipeline._unpack_latents(
        packed_latents,
        height=1024,
        width=1024,
        vae_scale_factor=pipeline.vae_scale_factor,
    )
    latents = latents.to(DTYPE)
    return latents

def generate_eta_values(
    timesteps,
    start_step,
    end_step,
    eta,
    eta_trend,
):
    # eta_values has len(timesteps) - 1 entries, so end_step may be at most len(timesteps) - 1
    assert 0 <= start_step < end_step <= len(timesteps) - 1, "Invalid start_step and end_step"
    # timesteps are monotonically decreasing, from 1.0 to 0.0
    eta_values = [0.0] * (len(timesteps) - 1)

    if eta_trend == 'constant':
        for i in range(start_step, end_step):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        total_time = timesteps[start_step] - timesteps[end_step - 1]
        for i in range(start_step, end_step):
            eta_values[i] = eta * (timesteps[start_step] - timesteps[i]) / total_time
    elif eta_trend == 'linear_decrease':
        total_time = timesteps[start_step] - timesteps[end_step - 1]
        for i in range(start_step, end_step):
            eta_values[i] = eta * (timesteps[i] - timesteps[end_step - 1]) / total_time
    else:
        raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}")

    return eta_values
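
# Illustrative example (added comment, not in the original script): with timesteps
# [1.0, 0.75, 0.5, 0.25, 0.0] (4 transitions), start_step=0, end_step=3, eta=0.8:
#   constant        -> [0.8, 0.8, 0.8, 0.0]
#   linear_increase -> [0.0, 0.4, 0.8, 0.0]
#   linear_decrease -> [0.8, 0.4, 0.0, 0.0]
# i.e. eta is only applied on steps [start_step, end_step); outside that window the
# update falls back to the plain prompt-conditioned Flux velocity.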

@torch.inference_mode()
def interpolated_denoise(
    pipeline,
    img_latents,
    eta_base,    # base eta value
    eta_trend,   # constant, linear_increase, linear_decrease
    start_step,  # 0-based indexing, closed interval
    end_step,    # 0-based indexing, open interval
    inversed_latents,  # can be None if not using inversed latents
    use_inversed_latents=True,
    guidance_scale=3.5,
    prompt='photo of a tiger',
    DTYPE=torch.bfloat16,
    num_steps=28,
    use_shift_t_sampling=True,
):
    timesteps = get_schedule(
        num_steps=num_steps,
        image_seq_len=(1024 // 16) * (1024 // 16),  # vae_scale_factor = 16
        shift=use_shift_t_sampling,
    )
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt,
        prompt_2=prompt
    )
    # use img_latents here so this also works when inversed_latents is None
    latent_image_ids = pipeline._prepare_latent_image_ids(
        img_latents.shape[0],
        img_latents.shape[2],
        img_latents.shape[3],
        img_latents.device,
        DTYPE,
    )
    if use_inversed_latents:
        packed_latents = pipeline._pack_latents(
            inversed_latents,
            batch_size=inversed_latents.shape[0],
            num_channels_latents=inversed_latents.shape[1],
            height=inversed_latents.shape[2],
            width=inversed_latents.shape[3],
        )
    else:
        tmp_latents = torch.randn_like(img_latents)
        packed_latents = pipeline._pack_latents(
            tmp_latents,
            batch_size=tmp_latents.shape[0],
            num_channels_latents=tmp_latents.shape[1],
            height=tmp_latents.shape[2],
            width=tmp_latents.shape[3],
        )

    packed_img_latents = pipeline._pack_latents(
        img_latents,
        batch_size=img_latents.shape[0],
        num_channels_latents=img_latents.shape[1],
        height=img_latents.shape[2],
        width=img_latents.shape[3],
    )

    target_img = packed_img_latents.clone().to(torch.float32)
    guidance_vec = torch.full((packed_latents.shape[0],), guidance_scale, device=packed_latents.device, dtype=packed_latents.dtype)

    eta_values = generate_eta_values(timesteps, start_step, end_step, eta_base, eta_trend)

    # Denoising with interpolated velocity field. t goes from 1.0 to 0.0
    with pipeline.progress_bar(total=len(timesteps) - 1) as progress_bar:
        for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, device=packed_latents.device)

            # Editing-text velocity
            flux_velocity = pipeline.transformer(
                hidden_states=packed_latents,
                timestep=t_vec,
                guidance=guidance_vec,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=None,
                return_dict=False,
            )[0]

            # Prevents precision issues
            packed_latents = packed_latents.to(torch.float32)
            flux_velocity = flux_velocity.to(torch.float32)

            # Target image velocity
            target_img_velocity = -(target_img - packed_latents) / t_curr

            # Interpolated velocity
            eta = eta_values[idx]
            interpolated_velocity = eta * target_img_velocity + (1 - eta) * flux_velocity
            packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity
            print(f"X_{t_prev:.3f} = X_{t_curr:.3f} + {t_prev - t_curr:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * flux_velocity)")

            packed_latents = packed_latents.to(DTYPE)
            progress_bar.update()

    latents = pipeline._unpack_latents(
        packed_latents,
        height=1024,
        width=1024,
        vae_scale_factor=pipeline.vae_scale_factor,
    )
    latents = latents.to(DTYPE)
    return latents

@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser(description='Test interpolated_denoise with different parameters.')
    parser.add_argument('--model_path', type=str, default='/root/autodl-tmp/Flux-dev', help='Path to the pretrained model')
    parser.add_argument('--image_path', type=str, default='./example/cat.png', help='Path to the input image')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Directory to save output images')
    parser.add_argument('--eta_base', type=float, default=1.0, help='Eta parameter for interpolated_denoise')
    parser.add_argument('--eta_trend', type=str, default='linear_decrease', choices=['constant', 'linear_increase', 'linear_decrease'], help='Eta trend for interpolated_denoise')
    parser.add_argument('--start_step', type=int, default=0, help='Start step for eta values, 0-based indexing, closed interval')
    parser.add_argument('--end_step', type=int, default=7, help='End step for eta values, 0-based indexing, open interval')
    parser.add_argument('--use_inversed_latents', action='store_true', help='Use inversed latents')
    parser.add_argument('--guidance_scale', type=float, default=3.5, help='Guidance scale for interpolated_denoise')
    parser.add_argument('--num_steps', type=int, default=28, help='Number of steps for timesteps')
    parser.add_argument('--shift', action='store_true', help='Use shift in get_schedule')
    parser.add_argument('--gamma', type=float, default=0.5, help='Gamma parameter for interpolated_inversion')
    parser.add_argument('--prompt', type=str, default='photo of a tiger', help='Prompt text for generation')
    parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16', 'float32'], help='Data type for computations')

    args = parser.parse_args()

    if args.dtype == 'bfloat16':
        DTYPE = torch.bfloat16
    elif args.dtype == 'float16':
        DTYPE = torch.float16
    elif args.dtype == 'float32':
        DTYPE = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = FluxPipeline.from_pretrained(args.model_path, torch_dtype=DTYPE)
    pipe.to(device)

    # Create output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Load and preprocess the image (force RGB in case the input has an alpha channel)
    img = Image.open(args.image_path).convert("RGB")

    train_transforms = transforms.Compose(
        [
            transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(1024),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    img = train_transforms(img).unsqueeze(0).to(device).to(DTYPE)

    # Encode image to latents
    img_latent = encode_imgs(img, pipe, DTYPE)

    if args.use_inversed_latents:
        inversed_latent = interpolated_inversion(pipe, img_latent, gamma=args.gamma, DTYPE=DTYPE, num_steps=args.num_steps, use_shift_t_sampling=False)
    else:
        inversed_latent = None

    # Denoise
    img_latents = interpolated_denoise(
        pipe,
        img_latent,
        eta_base=args.eta_base,
        eta_trend=args.eta_trend,
        start_step=args.start_step,
        end_step=args.end_step,
        inversed_latents=inversed_latent,
        use_inversed_latents=args.use_inversed_latents,
        guidance_scale=args.guidance_scale,
        prompt=args.prompt,
        DTYPE=DTYPE,
        num_steps=args.num_steps,
        use_shift_t_sampling=args.shift,
    )

    # Decode latents to images
    out = decode_imgs(img_latents, pipe)[0]

    # Save output image
    output_filename = f"eta{args.eta_base}_{args.eta_trend}_start{args.start_step}_end{args.end_step}_inversed{args.use_inversed_latents}_guidance{args.guidance_scale}.png"
    output_path = os.path.join(args.output_dir, output_filename)
    out.save(output_path)
    print(f"Saved output image to {output_path} with parameters: eta_base={args.eta_base}, start_step={args.start_step}, end_step={args.end_step}, guidance_scale={args.guidance_scale}")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------