├── LICENSE
├── README.md
├── Slicedit
│   ├── inversion_utils.py
│   ├── slicedit_attention_utils.py
│   └── video_utils.py
├── Videos
│   └── parkour.mp4
├── main.py
├── requirements.txt
└── yaml_files
    ├── dataset_configs
    │   └── parkour.yaml
    └── exp_configs
        └── default_exp_params.yaml

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Nathaniel Cohen, Vladimir Kulikov, Matan Kleiner, Inbar Huberman-Spiegelglas, Tomer Michaeli
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Zero-Shot Video Editing](https://img.shields.io/badge/zero%20shot-video%20editing-Green)](https://github.com/topics/video-editing)
2 | [![Python](https://img.shields.io/badge/python-3.8+-blue?python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54)](https://www.python.org/downloads/release/python-38/)
3 | ![PyTorch](https://img.shields.io/badge/torch-2.0.0-red?PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white)
4 | 
5 | # Slicedit
6 | 
7 | [Project](https://matankleiner.github.io/slicedit/) | [Arxiv](https://arxiv.org/abs/2405.12211) | [Proceedings](https://proceedings.mlr.press/v235/cohen24a.html)
8 | ### [ICML 2024] Official PyTorch implementation of the paper: "Slicedit: Zero-Shot Video Editing With Text-to-Image Diffusion Models Using Spatio-Temporal Slices"
9 | 
10 | https://github.com/fallenshock/Slicedit/assets/63591190/a92eebce-d276-4bef-a167-3aa272fb58ca
11 | 
12 | ## Installation
13 | 1. Clone the repository
14 | 
15 | 2. Install the required dependencies: `pip install -r requirements.txt`
16 |    * Tested with CUDA version 12.0 and diffusers 0.21.2
17 | ## Usage
18 | 1. Place desired input videos into the `Videos` folder
19 | 
20 | 2. Place desired dataset config yaml file into `yaml_files/dataset_configs`
21 | 
22 | 3. Change the experiment config yaml in `yaml_files/exp_configs` if desired
23 | 
24 | Note: The dataset config specifies the video name, source prompt and target prompt(s).
25 | Experiment configs specify the hyperparameters for the run. Use the provided default yamls as reference.
26 | 
27 | 4. Run `python main.py --dataset_yaml <path to dataset config yaml>` (see the example run below)
28 |    * Optional: passing `--use_negative_tar_prompt` improves sharpness.
29 | 
30 | ## License
31 | This project is licensed under the [MIT License](LICENSE). 
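### Example run
The repository ships with a sample clip, `Videos/parkour.mp4`, and a matching dataset config, `yaml_files/dataset_configs/parkour.yaml`:

```yaml
-
  video_name: parkour
  source_prompt: a man is jumping

  target_prompts:
    - a shiny silver robot is jumping
```

With the default experiment config, step 4 of the Usage section then becomes (the negative target prompt flag is optional):

```bash
python main.py --dataset_yaml ./yaml_files/dataset_configs/parkour.yaml --use_negative_tar_prompt
```

Edited frames and the output videos are written under `output_data/`.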
32 | 33 | 34 | ### Citation 35 | If you use this code for your research, please cite our paper: 36 | 37 | ``` 38 | @InProceedings{cohen2024slicedit, 39 | title={Slicedit: Zero-Shot Video Editing With Text-to-Image Diffusion Models Using Spatio-Temporal Slices}, 40 | author={Cohen, Nathaniel and Kulikov, Vladimir and Kleiner, Matan and Huberman-Spiegelglas, Inbar and Michaeli, Tomer}, 41 | booktitle={Proceedings of the 41st International Conference on Machine Learning}, 42 | pages={9109--9137}, 43 | year={2024}, 44 | volume={235}, 45 | series={Proceedings of Machine Learning Research}, 46 | month={21--27 Jul}, 47 | publisher={PMLR}, 48 | url={https://proceedings.mlr.press/v235/cohen24a.html}, 49 | ``` 50 | -------------------------------------------------------------------------------- /Slicedit/inversion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from tqdm import tqdm 4 | from diffusers import StableDiffusionPipeline 5 | from Slicedit.video_utils import image_grid 6 | import torch.nn as nn 7 | from torch.cuda.amp import autocast 8 | import numpy as np 9 | from Slicedit.slicedit_attention_utils import register_n_keyframes, register_extended_attention, register_conv_injection, register_time, register_denoiser_xy 10 | from Slicedit.video_utils import create_video_from_frames, save_video 11 | from PIL import Image 12 | from math import ceil as ceil 13 | import torch.nn.functional as F 14 | from einops import rearrange 15 | 16 | 17 | def sample_xts_from_x0_and_eps(model, x0, num_inference_steps=50, eps = None): 18 | """ 19 | Samples from P(x_1:T|x_0) 20 | """ 21 | # torch.manual_seed(43256465436) 22 | 23 | alpha_bar = model.scheduler.alphas_cumprod 24 | sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5 25 | variance_noise_shape = ( 26 | num_inference_steps, 27 | model.unet.in_channels, 28 | model.unet.sample_size, 29 | model.unet.sample_size) 30 | 31 | timesteps = model.scheduler.timesteps.to(model.device) 32 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)} 33 | 34 | xts = torch.zeros(variance_noise_shape).to(x0.device) 35 | for t in reversed(timesteps): 36 | idx = t_to_idx[int(t)] 37 | if eps is None: 38 | xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t] 39 | else: 40 | xts[idx] = x0 * (alpha_bar[t] ** 0.5) + eps[idx] * sqrt_one_minus_alpha_bar[t] 41 | xts = torch.cat([xts, x0 ],dim = 0) 42 | 43 | 44 | return xts 45 | 46 | 47 | def encode_text(model, prompts): 48 | text_input = model.tokenizer( 49 | prompts, 50 | padding="max_length", 51 | max_length=model.tokenizer.model_max_length, 52 | truncation=True, 53 | return_tensors="pt", 54 | ) 55 | with torch.no_grad(): 56 | text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0] 57 | return text_encoding 58 | 59 | 60 | def get_variance(model, timestep): 61 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 62 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] 63 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 64 | beta_prod_t = 1 - alpha_prod_t 65 | beta_prod_t_prev = 1 - alpha_prod_t_prev 66 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) 67 | return variance 68 | 69 | 70 | def calc_mu(model, model_output, timestep, sample, eta = 0): 71 | # get previous step value (=t-1) 72 | prev_timestep = timestep - 
model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 73 | # compute alphas, betas 74 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] 75 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 76 | beta_prod_t = 1 - alpha_prod_t 77 | # compute predicted original sample from predicted noise also called 78 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 79 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 80 | # compute variance: "sigma_t(η)" -> see formula (16) 81 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 82 | # variance = self.scheduler._get_variance(timestep, prev_timestep) 83 | variance = get_variance(model, timestep) #, prev_timestep) 84 | # std_dev_t = eta * variance ** (0.5) 85 | # Take care of asymetric reverse process (asyrp) 86 | model_output_direction = model_output 87 | # compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 88 | # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction 89 | 90 | pred_sample_direction = (1 - alpha_prod_t_prev - eta*variance) ** (0.5) * model_output_direction 91 | # compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 92 | mu = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 93 | # add noise if eta > 0 94 | 95 | return mu,variance 96 | 97 | 98 | def slicedit_ddpm_inversion_loop(model, x0_video, 99 | etas = None, 100 | prog_bar = False, 101 | src_prompt = "", 102 | tar_prompt = "", 103 | prompt_time = "", 104 | cfg_scale = 3.5, 105 | alpha=0.5, 106 | num_inference_steps=50, 107 | eps = None, 108 | n_frames = 64, 109 | skip=0, 110 | cfg_scale_tar=1, 111 | save_path="", 112 | orig_frame_size=(512,512), 113 | qk_injection_t=-1, 114 | conv_injection_t=-1, 115 | negative_time_prompt="", 116 | cfg_scale_time=1, 117 | frame_batch_size=6, 118 | use_negative_tar_prompt=False, 119 | x_and_y=False): 120 | 121 | # xts for source and target frames for all timesteps shape:[T_diff, src/tar, n_frames, 4, 64, 64] 122 | xts = torch.zeros((num_inference_steps+1,2,n_frames,4,64,64),device=model.device) 123 | # noisy volume for source and target frames for current timestep shape:[src/tar, n_frames, 4, 64, 64] 124 | volume_t = torch.zeros(2,n_frames, 4, 64, 64).to(model.device) 125 | # noise predictions for source and target frames for current timestep shape:[src/tar, n_frames, 4, 64, 64] 126 | noise_pred_xy = torch.zeros(2, n_frames, 4, 64, 64).to(model.device) 127 | 128 | # embed source and target text prompts 129 | if not src_prompt=="": 130 | text_embeddings = encode_text(model, src_prompt) 131 | text_embeddings_time = encode_text(model, prompt_time) 132 | if not tar_prompt=="": 133 | tar_emb = encode_text(model, tar_prompt) 134 | 135 | ## negative target prompt 136 | if use_negative_tar_prompt: 137 | negative_tar_prompt = "ugly, blurry, low res" 138 | else: 139 | negative_tar_prompt = "" 140 | neg_target_embedding = encode_text(model, negative_tar_prompt) 141 | 142 | 143 | negative_time_embedding = encode_text(model, negative_time_prompt) 144 | uncond_embedding = encode_text(model, "") 145 | 146 | # get timesteps 147 | timesteps = model.scheduler.timesteps.to(model.device) 148 | 149 | if type(etas) in [int, float]: etas = [etas]*(model.scheduler.num_inference_steps) 150 | ## we can specify a value of epsilon in order 
to be able to use a 151 | ## fixed epsilon for all the frames of a video 152 | for i in range(n_frames): 153 | xts[:,0,i,:,:,:] = sample_xts_from_x0_and_eps(model, x0_video[i], num_inference_steps=num_inference_steps, eps = eps) 154 | # initialize edit xts with source xts 155 | xts[:,1,i,:,:,:] = xts[:,0,i,:,:,:] 156 | 157 | 158 | if qk_injection_t == -1: 159 | qk_injection_t = num_inference_steps-skip 160 | else: 161 | qk_injection_t = (num_inference_steps-skip) * qk_injection_t // 100 162 | if conv_injection_t == -1: 163 | conv_injection_t = num_inference_steps-skip 164 | else: 165 | conv_injection_t = (num_inference_steps-skip) * conv_injection_t // 100 166 | qk_injection_timesteps = timesteps[skip:skip+qk_injection_t] if qk_injection_t >= 0 else [] 167 | conv_injection_timesteps = timesteps[skip:skip+conv_injection_t] if conv_injection_t >= 0 else [] 168 | 169 | register_extended_attention(model, qk_injection_timesteps) 170 | register_conv_injection(model, conv_injection_timesteps) 171 | 172 | 173 | n_keyframes = 3 174 | register_n_keyframes(model, n_keyframes) 175 | 176 | t_to_idx = {int(v):k for k,v in enumerate(timesteps[-(num_inference_steps-skip):])} 177 | op = tqdm(timesteps[-(num_inference_steps-skip):]) if prog_bar else timesteps[-(num_inference_steps-skip):] 178 | 179 | # initialize volume_t with source xts at timestep T-skip 180 | volume_t[1, :, :,:,:] = xts[skip,0,:,:,:,:] # .permute(1,2,0) 181 | 182 | for k,t in enumerate(op): 183 | 184 | idx = t_to_idx[int(t)] 185 | 186 | # noise_pred_xy holds both noise predictions for source and target 187 | noise_pred_xy = pred_noise_xy_func(xts, t, idx, uncond_embedding, neg_target_embedding, text_embeddings, model, cfg_scale, frame_batch_size, tar_emb, cfg_scale_tar, skip) 188 | volume_t[0, :, :,:,:] = xts[skip+idx,0,:,:,:,:] # .permute(1,2,0) 189 | # spatiotemporal denoise 190 | if alpha > 0: 191 | st_batch_size = 8 # reduce if having memory issues 192 | noise_pred_t = temporal_volume_denoise(model, n_frames, volume_t, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time, x_and_y=x_and_y) 193 | # combine noise predictions 194 | noise_pred_src = (np.sqrt(alpha))*noise_pred_t[0] + (np.sqrt(1-alpha))*noise_pred_xy[0] 195 | noise_pred_tar = (np.sqrt(alpha))*noise_pred_t[1] + (np.sqrt(1-alpha))*noise_pred_xy[1] 196 | else: 197 | noise_pred_src = noise_pred_xy[0] 198 | noise_pred_tar = noise_pred_xy[1] 199 | 200 | ## Edit Friendly DDPM inversion step https://arxiv.org/abs/2304.06140 201 | xtm1_src = xts[skip+idx+1,0,:,:,:,:] 202 | xt_src = xts[skip+idx,0,:,:,:,:] 203 | 204 | mu_xt_src,variance = calc_mu(model, noise_pred_src, t, xt_src,eta=etas[idx]) 205 | z_src = (xtm1_src - mu_xt_src ) / ( etas[idx] * variance ** 0.5 ) 206 | 207 | if k==num_inference_steps-1: 208 | z_src = torch.zeros_like(z_src) 209 | 210 | mu_xt_tar,variance = calc_mu(model, noise_pred_tar, t,volume_t[1],eta=etas[idx]) 211 | sigma_z = etas[idx] * variance ** (0.5) * z_src 212 | xt = mu_xt_tar + sigma_z 213 | volume_t[1] = xt 214 | xts[skip+k+1,1,:,:,:,:] = xt 215 | 216 | # error accumulation fix (as proposed in Edit Friendly DDPM Inversion) 217 | xtm1 = mu_xt_src + ( etas[idx] * variance ** 0.5 ) * z_src 218 | xts[skip+idx+1, 0, :, :,:,:] = xtm1 219 | 220 | ## end of diffusion loop. 
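    ## At this point volume_t[1] holds the edited latents of every frame at the
    ## final timestep. The code below decodes them with the VAE (undoing the
    ## 0.18215 latent scaling), resizes each frame back to its original
    ## resolution, saves the frames as PNGs, and assembles them into an mp4 via
    ## create_video_from_frames / save_video.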
221 | 222 | decoded_frames = [] 223 | for n in range(n_frames): 224 | w0 = volume_t[1,n,:,:,:].expand(1, -1, -1, -1) 225 | # vae decode image 226 | with torch.autocast("cuda"), torch.inference_mode(): 227 | x0_dec = model.vae.decode(1 / 0.18215 * w0).sample 228 | decoded_frames.append(x0_dec) 229 | img = image_grid(x0_dec) 230 | 231 | # resize images to original size and save them to disk as png files 232 | img = img.resize((orig_frame_size[1], orig_frame_size[0]), resample=Image.LANCZOS) 233 | image_name_png = f'frames/cfg_d_{[cfg_scale_tar][0]}_' + f'skip_{skip}_' + str(n) + ".png" 234 | 235 | os.makedirs(save_path+'/frames' , exist_ok=True) 236 | save_full_path = os.path.join(save_path, image_name_png) 237 | img.save(save_full_path) 238 | 239 | create_video_from_frames(save_path+'/frames') 240 | ## Uncomment the following lines to generate a compressed video without ffmpeg 241 | decoded_frames = torch.stack(decoded_frames) 242 | save_video(raw_frames=decoded_frames, save_path=os.path.join(save_path, 'slicedit_video_compressed.mp4'), fps=25, orig_size=orig_frame_size[:-1]) 243 | 244 | 245 | def pred_noise_xy_func(xts, 246 | t, 247 | idx, 248 | uncond_embedding, 249 | neg_target_embedding, 250 | text_embeddings, 251 | model, 252 | cfg_scale, 253 | frame_batch_size, 254 | tar_emb=None, 255 | cfg_scale_tar=1, 256 | skip=0, 257 | n_keyframes=3): 258 | 259 | # process both xt src and xt tar 260 | noise_pred_xy = torch.zeros(xts.shape[1], xts.shape[2], 4, 64, 64).to(model.device) 261 | n_frames = xts.shape[2] 262 | 263 | for n in range(ceil(n_frames/frame_batch_size)): 264 | batch_upper_idx = min((n+1)*frame_batch_size, n_frames) 265 | cur_batch_size = batch_upper_idx - n*frame_batch_size 266 | if n_keyframes == 3: 267 | # global frames 268 | idx1 = n_frames*3//6 269 | 270 | # add local frame indices 271 | if cur_batch_size == frame_batch_size: 272 | ref_frame_indices = [idx1, n*frame_batch_size + 1, n*frame_batch_size + 4] 273 | # last batch in video smaller than frame_batch_size 274 | else: 275 | ref_frame_indices = [idx1, n*frame_batch_size, n*frame_batch_size] 276 | else: 277 | raise NotImplementedError 278 | 279 | # xts contain both original and edited frames 280 | xt_src = xts[skip+idx,0, n*frame_batch_size:batch_upper_idx,:,:,:] 281 | xt_tar = xts[skip+idx,1, n*frame_batch_size:batch_upper_idx,:,:,:] 282 | 283 | # compute text embeddings 284 | text_embed_input = torch.cat([uncond_embedding.repeat(cur_batch_size+n_keyframes, 1, 1), 285 | text_embeddings.repeat(cur_batch_size+n_keyframes, 1, 1), 286 | neg_target_embedding.repeat(cur_batch_size+n_keyframes, 1, 1), 287 | tar_emb.repeat(cur_batch_size+n_keyframes, 1, 1)]) 288 | 289 | with torch.no_grad(): 290 | ref_frames_src = torch.cat([xts[skip+idx,0, ref_frame_indices[i],:,:,:].unsqueeze(0) for i in range(n_keyframes)], dim=0) 291 | ref_frames_tar = torch.cat([xts[skip+idx,1, ref_frame_indices[i],:,:,:].unsqueeze(0) for i in range(n_keyframes)], dim=0) 292 | 293 | # register time for the injection schedule 294 | register_time(model, t.item()) 295 | register_denoiser_xy(model, is_xy=True) 296 | # feed denoiser with a batch of current and reference frames for EA processing x4 (src_unc, src_cond, tar_unc, tar_cond) 297 | noise_pred = model.unet.forward(torch.cat(([xt_src]+[ref_frames_src]) * 2 + ([xt_tar]+[ref_frames_tar]) * 2), timestep = t, encoder_hidden_states = text_embed_input) 298 | # set denoiser to normal mode 299 | register_denoiser_xy(model, is_xy=False) 300 | 301 | # perform guidance 302 | noise_pred_src_uncond, 
noise_pred_src_cond, noise_pred_tar_uncond, noise_pred_tar_cond = noise_pred['sample'].chunk(4) 303 | noise_pred_src = noise_pred_src_uncond + cfg_scale * (noise_pred_src_cond - noise_pred_src_uncond) 304 | noise_pred_tar = noise_pred_tar_uncond + cfg_scale_tar * (noise_pred_tar_cond - noise_pred_tar_uncond) 305 | 306 | # store source and target noise predictions 307 | noise_pred_xy[0, n*frame_batch_size:batch_upper_idx, :,:,:] = noise_pred_src[:cur_batch_size] 308 | noise_pred_xy[1, n*frame_batch_size:batch_upper_idx, :,:,:] = noise_pred_tar[:cur_batch_size] 309 | 310 | return noise_pred_xy 311 | 312 | 313 | def temporal_volume_denoise(model, n_frames, volume_t, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time, x_and_y=False): 314 | noise_pred_t = torch.zeros(2, n_frames, 4, 64, 64).to(model.device) 315 | # calculate number of size 64 temporal windows in the video s.t. there is atleast 1 overlapping frame between them 316 | n_slices = ceil((n_frames-1) / 63) 317 | # calculate the overlaps between windows 318 | if n_slices > 1: 319 | # overlap = (n_slices*64-n_frames)//n_slices # 320 | overlap = (n_slices*64-n_frames)//(n_slices-1) 321 | # overlap_last = overlap+(n_slices*64-n_frames)%n_slices # 322 | overlap_last = overlap+(n_slices*64-n_frames)%(n_slices-1) 323 | overlap_list = [overlap]*(n_slices-2) + [overlap_last] 324 | # calculate the start indices of each window 325 | cumsum_overlap = np.cumsum(overlap_list) 326 | # sl_idxs = [0]+[i*64-i*overlap_list[i] for i in range(0, n_slices-1)] # 327 | sl_idxs = [0]+[(i+1)*64-cumsum_overlap[i] for i in range(0, n_slices-1)] 328 | 329 | else: 330 | overlap = 0 331 | overlap_last = 0 332 | overlap_list = [0] 333 | sl_idxs = [0] 334 | 335 | for idx, sl_i in enumerate(sl_idxs): 336 | for i in range(64//st_batch_size): 337 | 338 | xt = rearrange(volume_t[0,sl_i:sl_i+64, :, :, i*st_batch_size:(i+1)*st_batch_size], 'n c x y -> y c n x') 339 | noise_pred = temporal_denoise(model, xt, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time) 340 | noise_pred_t[0,sl_i:sl_i+64,:,:,i*st_batch_size:(i+1)*st_batch_size] += rearrange(noise_pred, 'y c n x -> n c x y') 341 | 342 | xt = rearrange(volume_t[1,sl_i:sl_i+64, :, :, i*st_batch_size:(i+1)*st_batch_size], 'n c x y -> y c n x') 343 | noise_pred = temporal_denoise(model, xt, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time) 344 | noise_pred_t[1,sl_i:sl_i+64,:,:,i*st_batch_size:(i+1)*st_batch_size] += rearrange(noise_pred, 'y c n x -> n c x y') 345 | 346 | # normalize noise levels in overlapping frames 347 | if sl_i > 0: 348 | noise_pred_t[:,sl_i:sl_i+overlap_list[idx-1],:,:,:] = (noise_pred_t[:,sl_i:sl_i+overlap_list[idx-1],:,:,:])*np.sqrt(1/2) 349 | 350 | if x_and_y: # add the x_t_slice noise prediction to the existing x_y_slice noise prediction 351 | volume_t_p = rearrange(volume_t, 's n c x y -> s n c y x').permute(0,1,3,2,4) 352 | noise_pred_t_x = torch.zeros_like(volume_t_p) 353 | 354 | for idx, sl_i in enumerate(sl_idxs): 355 | for i in range(64//st_batch_size): 356 | xt = rearrange(volume_t_p[0,sl_i:sl_i+64, :, :, i*st_batch_size:(i+1)*st_batch_size], 'n c x y -> y c n x') 357 | noise_pred = temporal_denoise(model, xt, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time) 358 | noise_pred_t_x[0,sl_i:sl_i+64,:,:,i*st_batch_size:(i+1)*st_batch_size] += rearrange(noise_pred, 'y c n x -> n c x y') 359 | 360 | 361 | xt = rearrange(volume_t_p[1,sl_i:sl_i+64, :, :, 
i*st_batch_size:(i+1)*st_batch_size], 'n c x y -> y c n x') 362 | noise_pred = temporal_denoise(model, xt, t, negative_time_embedding, text_embeddings_time, st_batch_size, cfg_scale_time) 363 | noise_pred_t_x[1,sl_i:sl_i+64,:,:,i*st_batch_size:(i+1)*st_batch_size] += rearrange(noise_pred, 'y c n x -> n c x y') 364 | # normalize noise levels in overlapping frames 365 | if sl_i > 0: 366 | noise_pred_t_x[:,sl_i:sl_i+overlap_list[idx-1],:,:,:] = (noise_pred_t_x[:,sl_i:sl_i+overlap_list[idx-1],:,:,:])*np.sqrt(1/2) 367 | 368 | noise_pred_t_x = rearrange(noise_pred_t_x, 's n c y x -> s n c x y') 369 | 370 | noise_pred_t = noise_pred_t + noise_pred_t_x # add the x_t_slice noise prediction to the existing x_y_slice noise prediction 371 | noise_pred_t = noise_pred_t * np.sqrt(1/2) # normalize noise levels between x and y slices 372 | 373 | return noise_pred_t 374 | 375 | 376 | def temporal_denoise(model, xt, t, neg_embedding, text_embeddings_time, st_batch_size, cfg_scale): 377 | 378 | if cfg_scale == 1: 379 | with torch.no_grad(): 380 | cond_out = model.unet.forward(xt, timestep = t, encoder_hidden_states = text_embeddings_time.repeat(st_batch_size, 1, 1)) 381 | return cond_out.sample 382 | 383 | ## Unconditional embedding 384 | with torch.no_grad(): 385 | uncond_out = model.unet.forward(xt, timestep = t, encoder_hidden_states = neg_embedding.repeat(st_batch_size, 1, 1)) 386 | 387 | ## Conditional embedding 388 | with torch.no_grad(): 389 | cond_out = model.unet.forward(xt, timestep = t, encoder_hidden_states = text_embeddings_time.repeat(st_batch_size, 1, 1)) 390 | ## classifier free guidance 391 | noise_pred = uncond_out.sample + cfg_scale * (cond_out.sample - uncond_out.sample) 392 | 393 | return noise_pred 394 | -------------------------------------------------------------------------------- /Slicedit/slicedit_attention_utils.py: -------------------------------------------------------------------------------- 1 | ## This file contains functions which were based upon the code from TokenFlow: https://github.com/omerbt/TokenFlow ## 2 | 3 | from typing import Type 4 | import torch 5 | 6 | def isinstance_str(x: object, cls_name: str): 7 | """ 8 | Checks whether x has any class *named* cls_name in its ancestry. 9 | Doesn't require access to the class's implementation. 10 | 11 | Useful for patching! 
12 | """ 13 | 14 | for _cls in x.__class__.__mro__: 15 | if _cls.__name__ == cls_name: 16 | return True 17 | 18 | return False 19 | 20 | # This function puts the denoiser into Extended Attention mode (for x_y slice denoising) 21 | def register_denoiser_xy(diffusion_model, is_xy): 22 | down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]} 23 | up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} 24 | for res in up_res_dict: 25 | for block in up_res_dict[res]: 26 | module = diffusion_model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 27 | setattr(module, 'is_xy', is_xy) 28 | 29 | module = diffusion_model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn2 30 | setattr(module, 'is_xy', is_xy) 31 | 32 | for res in down_res_dict: 33 | for block in down_res_dict[res]: 34 | module = diffusion_model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1 35 | setattr(module, 'is_xy', is_xy) 36 | 37 | module = diffusion_model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn2 38 | setattr(module, 'is_xy', is_xy) 39 | 40 | module = diffusion_model.unet.mid_block.attentions[0].transformer_blocks[0].attn1 41 | setattr(module, 'is_xy', is_xy) 42 | 43 | module = diffusion_model.unet.mid_block.attentions[0].transformer_blocks[0].attn2 44 | setattr(module, 'is_xy', is_xy) 45 | 46 | 47 | def register_time(model, t): 48 | conv_module = model.unet.up_blocks[1].resnets[1] 49 | setattr(conv_module, 't', t) 50 | 51 | down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]} 52 | up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} 53 | for res in up_res_dict: 54 | for block in up_res_dict[res]: 55 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 56 | setattr(module, 't', t) 57 | 58 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn2 59 | setattr(module, 't', t) 60 | 61 | for res in down_res_dict: 62 | for block in down_res_dict[res]: 63 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1 64 | setattr(module, 't', t) 65 | 66 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn2 67 | setattr(module, 't', t) 68 | 69 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1 70 | setattr(module, 't', t) 71 | 72 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn2 73 | setattr(module, 't', t) 74 | 75 | 76 | def register_n_keyframes(model, n_keyframes): 77 | 78 | down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]} 79 | up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} 80 | for res in up_res_dict: 81 | for block in up_res_dict[res]: 82 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 83 | setattr(module, 'n_keyframes', n_keyframes) 84 | 85 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn2 86 | setattr(module, 'n_keyframes', n_keyframes) 87 | 88 | for res in down_res_dict: 89 | for block in down_res_dict[res]: 90 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1 91 | setattr(module, 'n_keyframes', n_keyframes) 92 | 93 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn2 94 | setattr(module, 'n_keyframes', n_keyframes) 95 | 96 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1 97 | setattr(module, 'n_keyframes', n_keyframes) 98 | 99 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn2 100 | setattr(module, 
'n_keyframes', n_keyframes) 101 | 102 | 103 | def prepare_extended_QKV(self, token_mat, n_keyframes, n_orig_frames, sequence_length, h, dim, is_SA=False): 104 | if not is_SA: 105 | token_mat = token_mat.reshape(1, n_keyframes * sequence_length, -1).repeat(n_orig_frames, 1, 1) 106 | 107 | token_mat = self.head_to_batch_dim(token_mat) 108 | token_mat = token_mat.view(n_orig_frames, h, sequence_length * n_keyframes, dim // h) 109 | return token_mat 110 | 111 | 112 | def register_extended_attention(model, injection_schedule): 113 | def sa_forward(self): 114 | to_out = self.to_out 115 | if type(to_out) is torch.nn.modules.container.ModuleList: 116 | to_out = self.to_out[0] 117 | else: 118 | to_out = self.to_out 119 | 120 | def forward(x, encoder_hidden_states=None, attention_mask=None): 121 | 122 | batch_size, sequence_length, dim = x.shape 123 | h = self.heads 124 | if self.is_xy == True: 125 | n_frames = batch_size // 4 126 | is_cross = encoder_hidden_states is not None 127 | 128 | encoder_hidden_states = encoder_hidden_states if is_cross else x 129 | q = self.to_q(x) 130 | k = self.to_k(encoder_hidden_states) 131 | v = self.to_v(encoder_hidden_states) 132 | 133 | if self.injection_schedule is not None and (self.t in self.injection_schedule): 134 | # inject unconditional 135 | q[2*n_frames:3 * n_frames] = q[:n_frames] 136 | k[2*n_frames:3 * n_frames] = k[:n_frames] 137 | # inject conditional 138 | q[3 * n_frames:] = q[n_frames:2*n_frames] 139 | k[3 * n_frames:] = k[n_frames:2*n_frames] 140 | 141 | n_keyframes = self.n_keyframes # 1g2l 142 | idx_keyframe = n_frames-n_keyframes 143 | n_orig_frames = n_frames-n_keyframes 144 | 145 | # keyframe indices 146 | fr_idx_low_src_unc = idx_keyframe 147 | fr_idx_high_src_unc = idx_keyframe+n_keyframes 148 | fr_idx_low_src_cnd = n_frames+idx_keyframe 149 | fr_idx_high_src_cnd = n_frames+idx_keyframe+n_keyframes 150 | fr_idx_low_unc = 2*n_frames+idx_keyframe 151 | fr_idx_high_unc =2*n_frames+idx_keyframe+n_keyframes 152 | fr_idx_low_cnd = 3*n_frames+idx_keyframe 153 | fr_idx_high_cnd = 3*n_frames+idx_keyframe+n_keyframes 154 | 155 | # K 156 | k_src_unc = prepare_extended_QKV(self, k[fr_idx_low_src_unc:fr_idx_high_src_unc], n_keyframes, n_orig_frames, sequence_length, h, dim) 157 | k_src_cnd = prepare_extended_QKV(self, k[fr_idx_low_src_cnd:fr_idx_high_src_cnd], n_keyframes, n_orig_frames, sequence_length, h, dim) 158 | k_uncond = prepare_extended_QKV(self, k[fr_idx_low_unc:fr_idx_high_unc], n_keyframes, n_orig_frames, sequence_length, h, dim) 159 | k_cond = prepare_extended_QKV(self, k[fr_idx_low_cnd:fr_idx_high_cnd], n_keyframes, n_orig_frames, sequence_length, h, dim) 160 | # V 161 | v_src_unc = prepare_extended_QKV(self, v[fr_idx_low_src_unc:fr_idx_high_src_unc], n_keyframes, n_orig_frames, sequence_length, h, dim) 162 | v_src_cnd = prepare_extended_QKV(self, v[fr_idx_low_src_cnd:fr_idx_high_src_cnd], n_keyframes, n_orig_frames, sequence_length, h, dim) 163 | v_uncond = prepare_extended_QKV(self, v[fr_idx_low_unc:fr_idx_high_unc], n_keyframes, n_orig_frames, sequence_length, h, dim) 164 | v_cond = prepare_extended_QKV(self, v[fr_idx_low_cnd:fr_idx_high_cnd], n_keyframes, n_orig_frames, sequence_length, h, dim) 165 | 166 | # Q 167 | q_src_unc = prepare_extended_QKV(self, q[:n_orig_frames], 1, n_orig_frames, sequence_length, h, dim, is_SA=True) 168 | q_src_cnd = prepare_extended_QKV(self, q[n_frames: n_frames+n_orig_frames], 1, n_orig_frames, sequence_length, h, dim, is_SA=True) 169 | q_uncond = prepare_extended_QKV(self, q[2*n_frames: 
2*n_frames+n_orig_frames], 1, n_orig_frames, sequence_length, h, dim, is_SA=True) 170 | q_cond = prepare_extended_QKV(self, q[3*n_frames: 3*n_frames+n_orig_frames], 1, n_orig_frames, sequence_length, h, dim, is_SA=True) 171 | 172 | out_source_unc_all = [] 173 | out_source_cnd_all = [] 174 | out_uncond_all = [] 175 | out_cond_all = [] 176 | 177 | single_batch = True # n_frames <= 12 178 | b = n_orig_frames if single_batch else 1 179 | 180 | 181 | for frame in range(0, n_orig_frames, b): 182 | out_source_unc = [] 183 | out_source_cnd = [] 184 | out_uncond = [] 185 | out_cond = [] 186 | for j in range(h): 187 | sim_source_unc_b = torch.bmm(q_src_unc[frame: frame + b, j], 188 | k_src_unc[frame: frame + b, j].transpose(-1, -2)) * self.scale 189 | sim_source_cnd_b = torch.bmm(q_src_cnd[frame: frame + b, j], 190 | k_src_cnd[frame: frame + b, j].transpose(-1, -2)) * self.scale 191 | sim_uncond_b = torch.bmm(q_uncond[frame: frame + b, j], 192 | k_uncond[frame: frame + b, j].transpose(-1, -2)) * self.scale 193 | sim_cond = torch.bmm(q_cond[frame: frame + b, j], 194 | k_cond[frame: frame + b, j].transpose(-1, -2)) * self.scale 195 | 196 | out_source_unc.append(torch.bmm(sim_source_unc_b.softmax(dim=-1), v_src_unc[frame: frame + b, j])) 197 | out_source_cnd.append(torch.bmm(sim_source_cnd_b.softmax(dim=-1), v_src_cnd[frame: frame + b, j])) 198 | out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[frame: frame + b, j])) 199 | out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[frame: frame + b, j])) 200 | 201 | out_source_unc = torch.cat(out_source_unc, dim=0) 202 | out_source_cnd = torch.cat(out_source_cnd, dim=0) 203 | out_uncond = torch.cat(out_uncond, dim=0) 204 | out_cond = torch.cat(out_cond, dim=0) 205 | if single_batch: 206 | out_source_unc = out_source_unc.view(h, n_orig_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 207 | h * n_orig_frames, sequence_length, -1) 208 | out_source_cnd = out_source_cnd.view(h, n_orig_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 209 | h * n_orig_frames, sequence_length, -1) 210 | out_uncond = out_uncond.view(h, n_orig_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 211 | h * n_orig_frames, sequence_length, -1) 212 | out_cond = out_cond.view(h, n_orig_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 213 | h * n_orig_frames, sequence_length, -1) 214 | out_source_unc_all.append(out_source_unc) 215 | out_source_cnd_all.append(out_source_cnd) 216 | out_uncond_all.append(out_uncond) 217 | out_cond_all.append(out_cond) 218 | 219 | out_source_unc = torch.cat(out_source_unc_all, dim=0) 220 | out_source_cnd = torch.cat(out_source_cnd_all, dim=0) 221 | out_uncond = torch.cat(out_uncond_all, dim=0) 222 | out_cond = torch.cat(out_cond_all, dim=0) 223 | 224 | # self attention only for the keyframes 225 | # Q 226 | q_kf_src_unc = prepare_extended_QKV(self, q[fr_idx_low_src_unc:fr_idx_high_src_unc], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 227 | q_kf_src_cnd = prepare_extended_QKV(self, q[fr_idx_low_src_cnd:fr_idx_high_src_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 228 | q_kf_unc = prepare_extended_QKV(self, q[fr_idx_low_unc:fr_idx_high_unc], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 229 | q_kf_cnd = prepare_extended_QKV(self, q[fr_idx_low_cnd:fr_idx_high_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 230 | # K 231 | k_kf_src_unc = prepare_extended_QKV(self, k[fr_idx_low_src_unc:fr_idx_high_src_unc], 1, n_keyframes, sequence_length, h, dim, 
is_SA=True) 232 | k_kf_src_cnd = prepare_extended_QKV(self, k[fr_idx_low_src_cnd:fr_idx_high_src_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 233 | k_kf_unc = prepare_extended_QKV(self, k[fr_idx_low_unc:fr_idx_high_unc], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 234 | k_kf_cnd = prepare_extended_QKV(self, k[fr_idx_low_cnd:fr_idx_high_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 235 | # V 236 | v_kf_src_unc = prepare_extended_QKV(self, v[fr_idx_low_src_unc:fr_idx_high_src_unc], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 237 | v_kf_src_cnd = prepare_extended_QKV(self, v[fr_idx_low_src_cnd:fr_idx_high_src_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 238 | v_kf_unc = prepare_extended_QKV(self, v[fr_idx_low_unc:fr_idx_high_unc], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 239 | v_kf_cnd = prepare_extended_QKV(self, v[fr_idx_low_cnd:fr_idx_high_cnd], 1, n_keyframes, sequence_length, h, dim, is_SA=True) 240 | 241 | out_kf_src_unc_all = [] 242 | out_kf_src_cnd_all = [] 243 | out_kf_unc_all = [] 244 | out_kf_cnd_all = [] 245 | 246 | b = n_keyframes if single_batch else 1 247 | for frame in range(0, n_keyframes, b): 248 | 249 | out_kf_src_unc = [] 250 | out_kf_src_cnd = [] 251 | out_kf_unc = [] 252 | out_kf_cnd = [] 253 | for j in range(h): 254 | sim_kf_src_unc = torch.bmm(q_kf_src_unc[frame: frame + b, j], 255 | k_kf_src_unc[frame: frame + b, j].transpose(-1, -2)) * self.scale 256 | sim_kf_src_cnd = torch.bmm(q_kf_src_cnd[frame: frame + b, j], 257 | k_kf_src_cnd[frame: frame + b, j].transpose(-1, -2)) * self.scale 258 | sim_kf_unc = torch.bmm(q_kf_unc[frame: frame + b, j], 259 | k_kf_unc[frame: frame + b, j].transpose(-1, -2)) * self.scale 260 | sim_kf_cnd = torch.bmm(q_kf_cnd[frame: frame + b, j], 261 | k_kf_cnd[frame: frame + b, j].transpose(-1, -2)) * self.scale 262 | 263 | out_kf_src_unc.append(torch.bmm(sim_kf_src_unc.softmax(dim=-1), v_kf_src_unc[frame: frame + b, j])) 264 | out_kf_src_cnd.append(torch.bmm(sim_kf_src_cnd.softmax(dim=-1), v_kf_src_cnd[frame: frame + b, j])) 265 | out_kf_unc.append(torch.bmm(sim_kf_unc.softmax(dim=-1), v_kf_unc[frame: frame + b, j])) 266 | out_kf_cnd.append(torch.bmm(sim_kf_cnd.softmax(dim=-1), v_kf_cnd[frame: frame + b, j])) 267 | 268 | out_kf_src_unc = torch.cat(out_kf_src_unc, dim=0) 269 | out_kf_src_cnd = torch.cat(out_kf_src_cnd, dim=0) 270 | out_kf_unc = torch.cat(out_kf_unc, dim=0) 271 | out_kf_cnd = torch.cat(out_kf_cnd, dim=0) 272 | if single_batch: 273 | out_kf_src_unc = out_kf_src_unc.view(h, n_keyframes, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 274 | h * n_keyframes, sequence_length, -1) 275 | out_kf_src_cnd = out_kf_src_cnd.view(h, n_keyframes, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 276 | h * n_keyframes, sequence_length, -1) 277 | out_kf_unc = out_kf_unc.view(h, n_keyframes, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 278 | h * n_keyframes, sequence_length, -1) 279 | out_kf_cnd = out_kf_cnd.view(h, n_keyframes, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 280 | h * n_keyframes, sequence_length, -1) 281 | out_kf_src_unc_all.append(out_kf_src_unc) 282 | out_kf_src_cnd_all.append(out_kf_src_cnd) 283 | out_kf_unc_all.append(out_kf_unc) 284 | out_kf_cnd_all.append(out_kf_cnd) 285 | 286 | out_kf_src_unc = torch.cat(out_kf_src_unc_all, dim=0) 287 | out_kf_src_cnd = torch.cat(out_kf_src_cnd_all, dim=0) 288 | out_kf_unc = torch.cat(out_kf_unc_all, dim=0) 289 | out_kf_cnd = torch.cat(out_kf_cnd_all, dim=0) 290 | 291 | out = 
torch.cat([out_source_unc, out_kf_src_unc, out_source_cnd, out_kf_src_cnd, out_uncond, out_kf_unc, out_cond, out_kf_cnd], dim=0) 292 | out = self.batch_to_head_dim(out) 293 | 294 | return to_out(out) 295 | 296 | # temporal slices 297 | else: 298 | n_frames = batch_size 299 | is_cross = encoder_hidden_states is not None 300 | encoder_hidden_states = encoder_hidden_states if is_cross else x 301 | q = self.to_q(x) 302 | k = self.to_k(encoder_hidden_states) 303 | v = self.to_v(encoder_hidden_states) 304 | 305 | # self attention only for temporal frames 306 | 307 | q = self.head_to_batch_dim(q) 308 | k = self.head_to_batch_dim(k) 309 | v = self.head_to_batch_dim(v) 310 | 311 | q = q.view(n_frames, h, sequence_length, dim // h) 312 | k = k.view(n_frames, h, sequence_length, dim // h) 313 | v = v.view(n_frames, h, sequence_length, dim // h) 314 | 315 | out_all = [] 316 | single_batch = True # n_frames <= 12 317 | b = n_frames if single_batch else 1 318 | for frame in range(0, n_frames, b): 319 | 320 | out_ = [] 321 | for j in range(h): 322 | sim = torch.bmm(q[frame: frame + b, j], 323 | k[frame: frame + b, j].transpose(-1, -2)) * self.scale 324 | 325 | out_.append(torch.bmm(sim.softmax(dim=-1), v[frame: frame + b, j])) 326 | 327 | out_ = torch.cat(out_, dim=0) 328 | 329 | if single_batch: 330 | out_ = out_.view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape( 331 | h * n_frames, sequence_length, -1) 332 | 333 | out_all.append(out_) 334 | 335 | out_= torch.cat(out_all, dim=0) 336 | 337 | out = out_ 338 | out = self.batch_to_head_dim(out) 339 | 340 | return to_out(out) 341 | 342 | return forward 343 | 344 | for _, module in model.unet.named_modules(): 345 | if isinstance_str(module, "BasicTransformerBlock"): 346 | module.attn1.forward = sa_forward(module.attn1) 347 | setattr(module.attn1, 'injection_schedule', []) 348 | 349 | 350 | res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} 351 | # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution 352 | for res in res_dict: 353 | for block in res_dict[res]: 354 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 355 | module.forward = sa_forward(module) 356 | setattr(module, 'injection_schedule', injection_schedule) 357 | 358 | 359 | def register_conv_injection(model, injection_schedule): 360 | def conv_forward(self): 361 | def forward(input_tensor, temb, scale): 362 | scale = None 363 | hidden_states = input_tensor 364 | 365 | hidden_states = self.norm1(hidden_states) 366 | hidden_states = self.nonlinearity(hidden_states) 367 | 368 | if self.upsample is not None: 369 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 370 | if hidden_states.shape[0] >= 64: 371 | input_tensor = input_tensor.contiguous() 372 | hidden_states = hidden_states.contiguous() 373 | input_tensor = self.upsample(input_tensor) 374 | hidden_states = self.upsample(hidden_states) 375 | elif self.downsample is not None: 376 | input_tensor = self.downsample(input_tensor) 377 | hidden_states = self.downsample(hidden_states) 378 | 379 | hidden_states = self.conv1(hidden_states) 380 | 381 | if temb is not None: 382 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] 383 | 384 | if temb is not None and self.time_embedding_norm == "default": 385 | hidden_states = hidden_states + temb 386 | 387 | hidden_states = self.norm2(hidden_states) 388 | 389 | if temb is not None and self.time_embedding_norm == "scale_shift": 390 | scale, shift = torch.chunk(temb, 2, dim=1) 391 | hidden_states = hidden_states * (1 + scale) + shift 392 | 393 | hidden_states = self.nonlinearity(hidden_states) 394 | 395 | hidden_states = self.dropout(hidden_states) 396 | hidden_states = self.conv2(hidden_states) 397 | if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000): 398 | source_batch_size = int(hidden_states.shape[0] // 4) 399 | # inject unconditional src 400 | hidden_states[2*source_batch_size:3 * source_batch_size] = hidden_states[:source_batch_size] 401 | # inject conditional 402 | hidden_states[3 * source_batch_size:] = hidden_states[source_batch_size:2*source_batch_size] 403 | 404 | if self.conv_shortcut is not None: 405 | input_tensor = self.conv_shortcut(input_tensor) 406 | 407 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 408 | 409 | return output_tensor 410 | 411 | return forward 412 | 413 | conv_module = model.unet.up_blocks[1].resnets[1] # layer 4 only? 
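    # Note: conv feature injection is registered only on this single decoder
    # ResNet block (up_blocks[1].resnets[1]), while the attention injection
    # above covers decoder blocks 4-11, following the attention/feature
    # injection scheme (https://arxiv.org/abs/2211.12572) referenced in main.py.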
414 | conv_module.forward = conv_forward(conv_module) 415 | setattr(conv_module, 'injection_schedule', injection_schedule) 416 | -------------------------------------------------------------------------------- /Slicedit/video_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import torch 5 | import numpy as np 6 | import random 7 | import yaml 8 | 9 | import os 10 | import cv2 11 | 12 | 13 | from PIL import Image 14 | 15 | from PIL import Image, ImageDraw ,ImageFont 16 | import torchvision.transforms as T 17 | 18 | from torchvision.io import write_video 19 | 20 | def create_video_from_frames(frame_directory): 21 | 22 | video_directory = '/'.join(frame_directory.split('/')[:-1]) 23 | video_name = os.path.join(video_directory, "slicedit_out.mp4") 24 | 25 | # Check if video already exists 26 | # if so, skip this directory 27 | if os.path.exists(video_name): 28 | print(f"Video already exists for {video_name} !") 29 | return 30 | 31 | images = [img for img in sorted(os.listdir(frame_directory), key=lambda x: (len(x), x)) 32 | if img.endswith(".jpg") or 33 | img.endswith(".jpeg") or 34 | img.endswith("png")] 35 | 36 | if len(images) > 0: 37 | height, width, _ = cv2.imread(os.path.join(frame_directory, images[0])).shape 38 | 39 | video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*"mp4v"), 25, (width, height), True) 40 | 41 | print("Generating Video...") 42 | for image in images: 43 | curr_frame = cv2.imread(os.path.join(frame_directory, image)) 44 | 45 | video.write(curr_frame) 46 | 47 | cv2.destroyAllWindows() 48 | video.release() 49 | 50 | compressed_video_name = os.path.join(video_directory, "Slicedit_out_compressed.mp4") 51 | 52 | print(compressed_video_name) 53 | # Convert video to higher MPEG-4 compression 54 | os.system(f'ffmpeg -i "{video_name}" -c:v libx264 -crf 23 -preset medium -y "{compressed_video_name}"') 55 | 56 | print(f"Video created for {frame_directory}") 57 | 58 | 59 | # use this if you are having problems with ffmpeg 60 | def save_video(raw_frames, save_path, fps=25, orig_size=(512, 512)): 61 | video_codec = "libx264" 62 | video_options = { 63 | "crf": "23", 64 | "preset": "medium", 65 | } 66 | # squeeze and reshape frames to orig size using LANCOZ interpolation 67 | frames = torch.nn.functional.interpolate(raw_frames.squeeze(), orig_size, mode="area") 68 | 69 | frames = (((frames+1)/2).clamp(0, 1) * 255).to(torch.uint8).cpu().permute(0, 2, 3, 1) 70 | write_video(save_path, frames, fps=fps, video_codec=video_codec, options=video_options) 71 | 72 | 73 | def add_margin(pil_img, top = 0, right = 0, bottom = 0, 74 | left = 0, color = (255,255,255)): 75 | width, height = pil_img.size 76 | new_width = width + right + left 77 | new_height = height + top + bottom 78 | result = Image.new(pil_img.mode, (new_width, new_height), color) 79 | 80 | result.paste(pil_img, (left, top)) 81 | return result 82 | 83 | 84 | def tensor_to_pil(tensor_imgs): 85 | if type(tensor_imgs) == list: 86 | tensor_imgs = torch.cat(tensor_imgs) 87 | tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1) 88 | to_pil = T.ToPILImage() 89 | pil_imgs = [to_pil(img) for img in tensor_imgs] 90 | return pil_imgs 91 | 92 | 93 | def image_grid(imgs, rows = 1, cols = None, 94 | size = None, 95 | titles = None, text_pos = (0, 0)): 96 | if type(imgs) == list and type(imgs[0]) == torch.Tensor: 97 | imgs = torch.cat(imgs) 98 | if type(imgs) == torch.Tensor: 99 | imgs = tensor_to_pil(imgs) 100 | 101 | if not size is None: 102 | imgs = 
[img.resize((size,size)) for img in imgs] 103 | if cols is None: 104 | cols = len(imgs) 105 | assert len(imgs) >= rows*cols 106 | 107 | top=20 108 | w, h = imgs[0].size 109 | delta = 0 110 | if len(imgs)> 1 and not imgs[1].size[1] == h: 111 | delta = top 112 | h = imgs[1].size[1] 113 | if not titles is None: 114 | font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 115 | size = 20, encoding="unic") 116 | h = top + h 117 | grid = Image.new('RGB', size=(cols*w, rows*h+delta)) 118 | for i, img in enumerate(imgs): 119 | 120 | if not titles is None: 121 | img = add_margin(img, top = top, bottom = 0,left=0) 122 | draw = ImageDraw.Draw(img) 123 | draw.text(text_pos, titles[i],(0,0,0), 124 | font = font) 125 | if not delta == 0 and i > 0: 126 | grid.paste(img, box=(i%cols*w, i//cols*h+delta)) 127 | else: 128 | grid.paste(img, box=(i%cols*w, i//cols*h)) 129 | 130 | return grid 131 | 132 | 133 | def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None): 134 | if type(image_path) is str: 135 | image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3] 136 | else: 137 | image = image_path 138 | h, w, c = image.shape 139 | left = min(left, w-1) 140 | right = min(right, w - left - 1) 141 | top = min(top, h - left - 1) 142 | bottom = min(bottom, h - top - 1) 143 | image = image[top:h-bottom, left:w-right] 144 | h, w, c = image.shape 145 | if h < w: 146 | offset = (w - h) // 2 147 | image = image[:, offset:offset + h] 148 | elif w < h: 149 | offset = (h - w) // 2 150 | image = image[offset:offset + w] 151 | image = np.array(Image.fromarray(image).resize((512, 512))) 152 | image = torch.from_numpy(image).float() / 127.5 - 1 153 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 154 | 155 | return image 156 | 157 | 158 | def set_seed(seed=42): 159 | random.seed(seed) 160 | np.random.seed(seed) 161 | torch.manual_seed(seed) 162 | torch.cuda.manual_seed_all(seed) 163 | torch.backends.cudnn.deterministic = True 164 | torch.backends.cudnn.benchmark = False 165 | 166 | 167 | def extract_images(path, name, resize): 168 | print(f"Extracting frames from video...") 169 | # read the video from specified path 170 | cam = cv2.VideoCapture(path) 171 | if os.path.exists(f'data/{name}') and len(glob.glob(os.path.join(f'data/{name}/', '*.png')))>0: 172 | ret, frame = cam.read() 173 | return len(glob.glob(os.path.join(f'data/{name}/', '*.png'))), frame.shape 174 | try: 175 | # creating a folder named data 176 | if not os.path.exists(f'data/{name}'): 177 | os.makedirs(f'data/{name}') 178 | # if not created then raise error 179 | except OSError: 180 | print('Error: Creating directory of data') 181 | 182 | # frame 183 | current_frame = 0 184 | 185 | while(True): 186 | # reading frames from video 187 | ret, frame = cam.read() 188 | if ret: 189 | shape = frame.shape 190 | path_name = f'data/{name}/frame' + str(current_frame) + '.png' 191 | print ('Creating... ' + path_name) 192 | # saving the extracted frames and resize if needed 193 | if resize: 194 | resize_frame = cv2.resize(frame, dsize=(512,512), interpolation=cv2.INTER_CUBIC) 195 | cv2.imwrite(path_name, resize_frame) 196 | else: 197 | cv2.imwrite(path_name, frame) 198 | 199 | current_frame += 1 200 | last_frame = frame 201 | else: 202 | break 203 | if current_frame == 63 or current_frame == 62: 204 | while (current_frame != 64): 205 | path_name = f'data/{name}/frame' + str(current_frame) + '.png' 206 | print ('Copying... 
' + path_name) 207 | 208 | if resize: 209 | resize_frame = cv2.resize(last_frame, dsize=(512,512), interpolation=cv2.INTER_CUBIC) 210 | cv2.imwrite(path_name, resize_frame) 211 | else: 212 | cv2.imwrite(path_name, last_frame) 213 | current_frame += 1 214 | if current_frame == 40: 215 | while (current_frame != 64): 216 | path_name = f'data/{name}/frame' + str(current_frame) + '.png' 217 | print ('Copying... ' + path_name) 218 | 219 | if resize: 220 | resize_frame = cv2.resize(last_frame, dsize=(512,512), interpolation=cv2.INTER_CUBIC) 221 | cv2.imwrite(path_name, resize_frame) 222 | else: 223 | cv2.imwrite(path_name, last_frame) 224 | current_frame += 1 225 | 226 | # release all space and windows once done 227 | cam.release() 228 | cv2.destroyAllWindows() 229 | return current_frame, shape 230 | 231 | 232 | def resize_image_to_original(path, save_path, size): 233 | img = cv2.imread(path) 234 | res = cv2.resize(img, dsize=size, interpolation=cv2.INTER_CUBIC) 235 | cv2.imwrite(save_path, res) 236 | 237 | 238 | def dataset_from_yaml(yaml_location): 239 | with open(yaml_location, 'r') as stream: 240 | data_loaded = yaml.safe_load(stream) 241 | 242 | return data_loaded 243 | -------------------------------------------------------------------------------- /Videos/parkour.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fallenshock/Slicedit/492eef3f6746a9ffea76202f0424c28aef1ed33a/Videos/parkour.mp4 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import glob 5 | import torch 6 | 7 | from diffusers import StableDiffusionPipeline, DDIMScheduler 8 | 9 | from Slicedit.inversion_utils import slicedit_ddpm_inversion_loop 10 | from Slicedit.video_utils import load_512, extract_images, set_seed, dataset_from_yaml 11 | 12 | 13 | def print_details(args, skip, cfg_dec, prompt_enc, prompt_tar, qk_inj, alpha, nbr_frames): 14 | print("---------") 15 | print("Working on:") 16 | print("video name: " + args.video_name) 17 | print("video path: " + args.video_path) 18 | print("exp_name: " + args.exp_name) 19 | print("skip: " + str(skip)) 20 | print("decoder: " + str(cfg_dec)) 21 | print("num_diffusion_steps: " + str(args.num_diffusion_steps)) 22 | print("prompt encoder: " + prompt_enc) 23 | print("prompt decoder: " + prompt_tar) 24 | print("alpha: " + str(alpha)) 25 | print("qk_inj: " + str(qk_inj)) 26 | print("nbr_frames: " + str(nbr_frames)) 27 | print("---------") 28 | 29 | ## main function prepares the data and runs slicedit for all videos and prompts in the dataset yaml file 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | ## path to the yaml file with the dataset parameters. If not provided, argparse values are used! 34 | parser.add_argument("--dataset_yaml", type=str, default="./yaml_files/dataset_configs/parkour.yaml") 35 | ## path to the yaml file with the experiment parameters. If not provided, argparse values are used! 
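    ## Note: values from the experiment config yaml overwrite the argparse
    ## defaults below, and the dataset yaml supplies video_name, source_prompt
    ## and target_prompts per video (see the yaml parsing section further down).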
36 | parser.add_argument("--exp_config", type=str, default="./yaml_files/exp_configs/default_exp_params.yaml") # 37 | ## argpase parameters (overwritten by yaml files if provided) 38 | # general parameters 39 | parser.add_argument("--device_num", type=int, default=0) 40 | parser.add_argument("--seed", type=int, default=42) 41 | parser.add_argument("--exp_name", default="slicedit_test_") 42 | parser.add_argument("--skip", type=int, default=8) 43 | parser.add_argument("--alpha", type=float, default=0.2) 44 | parser.add_argument("--use_negative_tar_prompt", action='store_true') 45 | parser.add_argument("--x_and_y", action='store_true') 46 | # prompt parameters 47 | parser.add_argument("--prompt_enc", default="") 48 | parser.add_argument("--prompt_dec", default="") 49 | parser.add_argument("--cfg_enc", type=float, default=3.5) 50 | parser.add_argument("--cfg_dec", type=int, default=10) 51 | # percentage of attention and conv feature injection, similarly to https://arxiv.org/abs/2211.12572 52 | parser.add_argument("--qk_inj", type=int, default=85) # 100% means injection in all steps (num_diffusion_steps-skip) 53 | parser.add_argument("--conv_inj", type=int, default=0) # 100% means injection in all steps (num_diffusion_steps-skip) 54 | # diffusion parameters 55 | parser.add_argument("--model_id", type=str, default ="stabilityai/stable-diffusion-2-1-base") 56 | parser.add_argument("--num_diffusion_steps", type=int, default=50) 57 | parser.add_argument("--eta", type=float, default=1) 58 | # video parameters (default taken from dataset yaml) 59 | parser.add_argument("--video_path", type=str, default ="") 60 | parser.add_argument("--video_name", type=str, default ="") 61 | # negative time prompt for the ST slices (off by default cfg_time=1) 62 | parser.add_argument("--cfg_time", type=float, default=1) 63 | parser.add_argument("--prompt_time", default="") 64 | parser.add_argument("--negative_time_prompt", type=str, default="jittery") 65 | # enable xformers for (slightly faster) inference (if supported) 66 | parser.add_argument("--use_xformers", action='store_true') 67 | 68 | args = parser.parse_args() 69 | 70 | ## parse the yaml files 71 | 72 | # overwrite argparse with experiment yaml 73 | if args.exp_config != "": 74 | # iterate over all the options from config yaml and change args accordingly 75 | config_yaml_dict = dataset_from_yaml(args.exp_config) 76 | for key in config_yaml_dict.keys(): 77 | setattr(args, key, config_yaml_dict[key]) 78 | 79 | # overwrite argparse with dataset yaml 80 | if args.dataset_yaml != "": 81 | full_data = dataset_from_yaml(args.dataset_yaml) 82 | else: 83 | full_data = [{'video_name': args.video_name,'source_prompt': args.prompt_enc,'target_prompts': [args.prompt_dec]}] 84 | # set seed 85 | seed = args.seed 86 | set_seed(seed) 87 | 88 | assert args.eta > 0, "eta must be greater than 0 for DDPM" 89 | 90 | device = f"cuda:{args.device_num}" 91 | 92 | # load/reload model: "stabilityai/stable-diffusion-2-1-base" 93 | ldm_stable = StableDiffusionPipeline.from_pretrained(args.model_id).to(device) 94 | if args.use_xformers: 95 | ldm_stable.enable_xformers_memory_efficient_attention() 96 | 97 | ldm_stable.scheduler = DDIMScheduler.from_pretrained(args.model_id, subfolder="scheduler") 98 | ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps) 99 | 100 | prompt_time = args.prompt_time 101 | 102 | print(full_data) 103 | 104 | # run slicedit for all videos in the provided dataset yaml 105 | for i in range(len(full_data)): 106 | current_video_data = full_data[i] 107 | 
args.video_name = current_video_data['video_name'] 108 | print(args.video_name) 109 | 110 | args.video_path = './Videos/' + args.video_name + '.mp4' 111 | 112 | print(args.video_path) 113 | print(args.exp_name) 114 | prompt_enc = current_video_data.get('source_prompt', "") # default empty string 115 | prompt_tar_list = current_video_data['target_prompts'] 116 | 117 | # extract frames from the video and resize them 118 | frame_number, orig_frame_size = extract_images(args.video_path, f'{args.video_name}', resize=True) 119 | torch.save(torch.tensor(orig_frame_size), f'data/{args.video_name}/orig_frame_size.pt') 120 | 121 | # verify there is enough frames 122 | if frame_number < 64: 123 | print('Error : not enough frames') 124 | sys.exit() 125 | if frame_number > 210: 126 | print('clipping video to 210 frames') 127 | frame_number = 210 128 | 129 | print(full_data[i]) 130 | 131 | 132 | # run slicedit for all target prompts 133 | for k in range(len(prompt_tar_list)): 134 | prompt_tar = prompt_tar_list[k] 135 | print_details(args,args.skip,args.cfg_dec,prompt_enc,prompt_tar,args.qk_inj,args.alpha, frame_number) 136 | 137 | eps = [torch.randn(1, 4, 64, 64, device=device) for _ in range(args.num_diffusion_steps)] 138 | # encode video with VAE 139 | ws_video = [] 140 | for j in range(frame_number): 141 | image_path = f'data/{args.video_name}/frame{j}.png' 142 | offsets=(0,0,0,0) 143 | x0 = load_512(image_path, *offsets, device) 144 | # vae encode image 145 | with torch.autocast("cuda"), torch.inference_mode(): 146 | w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float() 147 | ws_video.append(w0) 148 | 149 | save_path = f'output_data/{args.video_name}_{args.exp_name}/{prompt_tar}/d_stps_{args.num_diffusion_steps}_alpha_{args.alpha}_cfg_e_{args.cfg_enc}_cfg_d_{args.cfg_dec}_skip_{args.skip}_seed_{seed}_qk_inj_{args.qk_inj}' 150 | 151 | try: 152 | os.makedirs(save_path, exist_ok=False) 153 | print(save_path) 154 | except: 155 | 156 | if len(glob.glob(save_path + 'frames/*.png') ) > 0: 157 | print(save_path, "already exists") 158 | continue 159 | else: 160 | print(save_path, "path exists but no frames found, continuing...") 161 | 162 | # save all args to txt file in save_path 163 | try: 164 | with open(save_path + '/args.txt', 'w') as f: 165 | for arg in vars(args): 166 | print(f"{arg}: {getattr(args, arg)}", file=f) 167 | print("", file=f) 168 | except: 169 | print("Failed to save args.txt") 170 | 171 | slicedit_ddpm_inversion_loop(ldm_stable, ws_video, 172 | etas = args.eta, 173 | prog_bar = True, 174 | src_prompt = prompt_enc, 175 | tar_prompt = prompt_tar, 176 | prompt_time = args.prompt_time, 177 | cfg_scale = args.cfg_enc, 178 | alpha=args.alpha, 179 | num_inference_steps=args.num_diffusion_steps, 180 | eps = eps, 181 | n_frames=frame_number, 182 | skip=args.skip, 183 | cfg_scale_tar=args.cfg_dec, 184 | save_path=save_path, 185 | orig_frame_size=orig_frame_size, 186 | qk_injection_t=args.qk_inj, 187 | conv_injection_t=args.conv_inj, 188 | negative_time_prompt=args.negative_time_prompt, 189 | cfg_scale_time=args.cfg_time, 190 | use_negative_tar_prompt=args.use_negative_tar_prompt, 191 | x_and_y=args.x_and_y) 192 | 193 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | av==12.0.0 3 | certifi==2024.2.2 4 | charset-normalizer==3.3.2 5 | diffusers==0.21.2 6 | einops==0.8.0 7 | filelock==3.14.0 8 | fsspec==2024.5.0 9 | 
huggingface-hub==0.23.0 10 | idna==3.7 11 | importlib_metadata==7.1.0 12 | Jinja2==3.1.4 13 | MarkupSafe==2.1.5 14 | mpmath==1.3.0 15 | networkx==3.2.1 16 | numpy==1.26.4 17 | nvidia-cublas-cu12==12.1.3.1 18 | nvidia-cuda-cupti-cu12==12.1.105 19 | nvidia-cuda-nvrtc-cu12==12.1.105 20 | nvidia-cuda-runtime-cu12==12.1.105 21 | nvidia-cudnn-cu12==8.9.2.26 22 | nvidia-cufft-cu12==11.0.2.54 23 | nvidia-curand-cu12==10.3.2.106 24 | nvidia-cusolver-cu12==11.4.5.107 25 | nvidia-cusparse-cu12==12.1.0.106 26 | nvidia-nccl-cu12==2.20.5 27 | nvidia-nvjitlink-cu12==12.4.127 28 | nvidia-nvtx-cu12==12.1.105 29 | opencv-python==4.9.0.80 30 | packaging==24.0 31 | pillow==10.3.0 32 | psutil==5.9.8 33 | PyYAML==6.0.1 34 | regex==2024.5.15 35 | requests==2.32.1 36 | safetensors==0.4.3 37 | sympy==1.12 38 | tokenizers==0.19.1 39 | torch==2.3.0 40 | torchaudio==2.3.0 41 | torchvision==0.18.0 42 | tqdm==4.66.4 43 | transformers==4.41.0 44 | triton==2.3.0 45 | typing_extensions==4.11.0 46 | urllib3==2.2.1 47 | zipp==3.18.2 48 | -------------------------------------------------------------------------------- /yaml_files/dataset_configs/parkour.yaml: -------------------------------------------------------------------------------- 1 | - 2 | video_name: parkour 3 | source_prompt: a man is jumping 4 | 5 | target_prompts: 6 | - a shiny silver robot is jumping 7 | -------------------------------------------------------------------------------- /yaml_files/exp_configs/default_exp_params.yaml: -------------------------------------------------------------------------------- 1 | device_num: 0 2 | exp_name: slicedit_exp_0 3 | cfg_enc: 3.5 4 | cfg_dec: 10 5 | skip: 8 6 | qk_inj: 85 7 | cfg_time: 1 8 | num_diffusion_steps: 50 9 | eta: 1 10 | model_id: stabilityai/stable-diffusion-2-1-base 11 | alpha: 0.2 12 | seed: 42 13 | conv_inj: 0 14 | negative_time_prompt: jittery 15 | x_and_y: false --------------------------------------------------------------------------------