├── assets
│   └── example
│       ├── example_01.png
│       ├── example_02.png
│       ├── example_03.png
│       └── example_04.png
├── .gitignore
├── configs
│   ├── prompts
│   │   └── default.yaml
│   ├── inference
│   │   ├── inference.yaml
│   │   └── inference_autoregress.yaml
│   └── training
│       └── training.yaml
├── environment.yaml
├── cog.yaml
├── LICENSE
├── consisti2v
│   ├── utils
│   │   ├── frameinit_utils.py
│   │   └── util.py
│   ├── models
│   │   ├── rotary_embedding.py
│   │   └── videoldm_transformer_blocks.py
│   ├── data
│   │   └── dataset.py
│   └── pipelines
│       ├── pipeline_autoregress_animation.py
│       └── pipeline_conditional_animation.py
├── README.md
├── predict.py
├── scripts
│   ├── animate.py
│   └── animate_autoregress.py
├── app.py
└── train.py

/assets/example/example_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/ConsistI2V/HEAD/assets/example/example_01.png
--------------------------------------------------------------------------------
/assets/example/example_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/ConsistI2V/HEAD/assets/example/example_02.png
--------------------------------------------------------------------------------
/assets/example/example_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/ConsistI2V/HEAD/assets/example/example_03.png
--------------------------------------------------------------------------------
/assets/example/example_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/ConsistI2V/HEAD/assets/example/example_04.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | samples/
2 | wandb/
3 | outputs/
4 | __pycache__/
5 | scripts/animate_inter.py
6 | scripts/gradio_app.py
7 | *.ipynb
8 | *.safetensors
9 | *.ckpt
10 | .ossutil_checkpoint/
11 | ossutil_output/
12 | debugs/
13 | .vscode
14 | .env
15 | models
16 | !*/models
17 | .ipynb_checkpoints
18 | checkpoints
--------------------------------------------------------------------------------
/configs/prompts/default.yaml:
--------------------------------------------------------------------------------
1 | seeds: random
2 | 
3 | prompts:
4 | - "timelapse at the snow land with aurora in the sky."
5 | - "fireworks."
6 | - "clown fish swimming through the coral reef."
7 | - "melting ice cream dripping down the cone."
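# Note: the lists in this file are paired element-wise by scripts/animate.py:
# prompt i is sampled with n_prompts[i], path_to_first_frames[i] and seeds[i].
# A single n_prompt or seed is broadcast to all prompts, and seeds may be "random".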
8 | 9 | n_prompts: 10 | - "" 11 | 12 | path_to_first_frames: 13 | - "assets/example/example_01.png" 14 | - "assets/example/example_02.png" 15 | - "assets/example/example_03.png" 16 | - "assets/example/example_04.png" -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: consisti2v 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.10 7 | - pytorch=2.1.0 8 | - torchvision=0.16.0 9 | - torchaudio=2.1.0 10 | - pytorch-cuda=11.8 11 | - pip 12 | - pip: 13 | - diffusers==0.21.2 14 | - transformers==4.25.1 15 | - accelerate==0.23.0 16 | - imageio==2.27.0 17 | - decord==0.6.0 18 | - einops 19 | - omegaconf 20 | - safetensors 21 | - gradio==3.42.0 22 | - wandb 23 | - moviepy 24 | - scikit-learn 25 | - av 26 | - rotary_embedding_torch 27 | - torchmetrics 28 | - torch-fidelity 29 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | system_packages: 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_version: "3.11" 10 | python_packages: 11 | - torch==2.0.1 12 | - torchvision==0.15.2 13 | - diffusers==0.21.2 14 | - transformers==4.25.1 15 | - accelerate==0.23.0 16 | - imageio==2.27.0 17 | - decord==0.6.0 18 | - einops 19 | - omegaconf 20 | - safetensors 21 | - wandb 22 | - moviepy 23 | - scikit-learn 24 | - av 25 | - rotary_embedding_torch 26 | - torchmetrics 27 | - torch-fidelity 28 | run: 29 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 30 | predict: "predict.py:Predictor" 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 TIGER Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/inference/inference.yaml: -------------------------------------------------------------------------------- 1 | output_dir: "samples/inference" 2 | output_name: "i2v" 3 | 4 | pretrained_model_path: "TIGER-Lab/ConsistI2V" 5 | unet_path: null 6 | unet_ckpt_prefix: "module." 7 | pipeline_pretrained_path: null 8 | 9 | sampling_kwargs: 10 | height: 256 11 | width: 256 12 | n_frames: 16 13 | steps: 50 14 | ddim_eta: 0.0 15 | guidance_scale_txt: 7.5 16 | guidance_scale_img: 1.0 17 | guidance_rescale: 0.0 18 | num_videos_per_prompt: 1 19 | frame_stride: 3 20 | 21 | unet_additional_kwargs: 22 | variant: null 23 | n_temp_heads: 8 24 | augment_temporal_attention: true 25 | temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" 26 | first_frame_condition_mode: "concat" 27 | use_frame_stride_condition: true 28 | noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" 29 | noise_alpha: 1.0 30 | 31 | noise_scheduler_kwargs: 32 | beta_start: 0.00085 33 | beta_end: 0.012 34 | beta_schedule: "linear" 35 | steps_offset: 1 36 | clip_sample: false 37 | rescale_betas_zero_snr: false # true if using zero terminal snr 38 | timestep_spacing: "leading" # "trailing" if using zero terminal snr 39 | prediction_type: "epsilon" # "v_prediction" if using zero terminal snr 40 | 41 | frameinit_kwargs: 42 | enable: true 43 | camera_motion: null 44 | noise_level: 850 45 | filter_params: 46 | method: 'gaussian' 47 | d_s: 0.25 48 | d_t: 0.25 -------------------------------------------------------------------------------- /configs/inference/inference_autoregress.yaml: -------------------------------------------------------------------------------- 1 | output_dir: "samples/inference" 2 | output_name: "long_video" 3 | 4 | pretrained_model_path: "TIGER-Lab/ConsistI2V" 5 | unet_path: null 6 | unet_ckpt_prefix: "module." 
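# unet_path may be either a directory holding a pretrained UNet or a raw training
# checkpoint file; for checkpoint files, the unet_ckpt_prefix string is stripped
# from every state-dict key before loading (see scripts/animate_autoregress.py).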
7 | pipeline_pretrained_path: null 8 | 9 | sampling_kwargs: 10 | height: 256 11 | width: 256 12 | n_frames: 16 13 | steps: 50 14 | ddim_eta: 0.0 15 | guidance_scale_txt: 7.5 16 | guidance_scale_img: 1.0 17 | guidance_rescale: 0.0 18 | num_videos_per_prompt: 1 19 | frame_stride: 3 20 | autoregress_steps: 3 21 | 22 | unet_additional_kwargs: 23 | variant: null 24 | n_temp_heads: 8 25 | augment_temporal_attention: true 26 | temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" 27 | first_frame_condition_mode: "concat" 28 | use_frame_stride_condition: true 29 | noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" 30 | noise_alpha: 1.0 31 | 32 | noise_scheduler_kwargs: 33 | beta_start: 0.00085 34 | beta_end: 0.012 35 | beta_schedule: "linear" 36 | steps_offset: 1 37 | clip_sample: false 38 | rescale_betas_zero_snr: false # true if using zero terminal snr 39 | timestep_spacing: "leading" # "trailing" if using zero terminal snr 40 | prediction_type: "epsilon" # "v_prediction" if using zero terminal snr 41 | 42 | 43 | frameinit_kwargs: 44 | enable: true 45 | noise_level: 850 46 | filter_params: 47 | method: 'gaussian' 48 | d_s: 0.25 49 | d_t: 0.25 -------------------------------------------------------------------------------- /configs/training/training.yaml: -------------------------------------------------------------------------------- 1 | output_dir: "checkpoints" 2 | pretrained_model_path: "stabilityai/stable-diffusion-2-1-base" 3 | 4 | noise_scheduler_kwargs: 5 | num_train_timesteps: 1000 6 | beta_start: 0.00085 7 | beta_end: 0.012 8 | beta_schedule: "linear" 9 | steps_offset: 1 10 | clip_sample: false 11 | rescale_betas_zero_snr: false # true if using zero terminal snr 12 | timestep_spacing: "leading" # "trailing" if using zero terminal snr 13 | prediction_type: "epsilon" # "v_prediction" if using zero terminal snr 14 | 15 | train_data: 16 | dataset: "joint" 17 | pexels_config: 18 | enable: false 19 | json_path: null 20 | caption_json_path: null 21 | video_folder: null 22 | webvid_config: 23 | enable: true 24 | json_path: "/path/to/webvid/annotation" 25 | video_folder: "/path/to/webvid/data" 26 | sample_size: 256 27 | sample_duration: null 28 | sample_fps: null 29 | sample_stride: [1, 5] 30 | sample_n_frames: 16 31 | 32 | validation_data: 33 | prompts: 34 | - "timelapse at the snow land with aurora in the sky." 35 | - "fireworks." 36 | - "clown fish swimming through the coral reef." 37 | - "melting ice cream dripping down the cone." 38 | 39 | path_to_first_frames: 40 | - "assets/example/example_01.jpg" 41 | - "assets/example/example_02.jpg" 42 | - "assets/example/example_03.jpg" 43 | - "assets/example/example_04.jpg" 44 | 45 | num_inference_steps: 50 46 | ddim_eta: 0.0 47 | guidance_scale_txt: 7.5 48 | guidance_scale_img: 1.0 49 | guidance_rescale: 0.0 50 | frame_stride: 3 51 | 52 | trainable_modules: 53 | - "all" 54 | # - "conv3ds." 55 | # - "tempo_attns." 
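# "all" trains every UNet parameter; the commented entries above suggest that
# listing name fragments (e.g. "conv3ds.", "tempo_attns.") restricts training to
# matching parameters; see train.py for the exact matching rule.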
56 | 57 | resume_from_checkpoint: null 58 | 59 | unet_additional_kwargs: 60 | variant: null 61 | n_temp_heads: 8 62 | augment_temporal_attention: true 63 | temp_pos_embedding: "rotary" # "rotary" or "sinusoidal" 64 | first_frame_condition_mode: "concat" 65 | use_frame_stride_condition: true 66 | noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive" 67 | noise_alpha: 1.0 68 | 69 | cfg_random_null_text_ratio: 0.1 70 | cfg_random_null_img_ratio: 0.1 71 | 72 | use_ema: false 73 | ema_decay: 0.9999 74 | 75 | learning_rate: 5.e-5 76 | train_batch_size: 3 77 | gradient_accumulation_steps: 1 78 | max_grad_norm: 0.5 79 | 80 | max_train_epoch: -1 81 | max_train_steps: 200000 82 | checkpointing_epochs: -1 83 | checkpointing_steps: 2000 84 | validation_steps: 1000 85 | 86 | seed: 42 87 | mixed_precision: "bf16" 88 | num_workers: 32 89 | enable_xformers_memory_efficient_attention: true 90 | 91 | is_image: false 92 | is_debug: false 93 | -------------------------------------------------------------------------------- /consisti2v/utils/frameinit_utils.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/TianxingWu/FreeInit/blob/master/freeinit_utils.py 2 | import torch 3 | import torch.fft as fft 4 | import math 5 | 6 | 7 | def freq_mix_3d(x, noise, LPF): 8 | """ 9 | Noise reinitialization. 10 | 11 | Args: 12 | x: diffused latent 13 | noise: randomly sampled noise 14 | LPF: low pass filter 15 | """ 16 | # FFT 17 | x_freq = fft.fftn(x, dim=(-3, -2, -1)) 18 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) 19 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) 20 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) 21 | 22 | # frequency mix 23 | HPF = 1 - LPF 24 | x_freq_low = x_freq * LPF 25 | noise_freq_high = noise_freq * HPF 26 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain 27 | 28 | # IFFT 29 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) 30 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real 31 | 32 | return x_mixed 33 | 34 | 35 | def get_freq_filter(shape, device, filter_type, n, d_s, d_t): 36 | """ 37 | Form the frequency filter for noise reinitialization. 38 | 39 | Args: 40 | shape: shape of latent (B, C, T, H, W) 41 | filter_type: type of the freq filter 42 | n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian 43 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 44 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 45 | """ 46 | if filter_type == "gaussian": 47 | return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 48 | elif filter_type == "ideal": 49 | return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 50 | elif filter_type == "box": 51 | return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) 52 | elif filter_type == "butterworth": 53 | return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) 54 | else: 55 | raise NotImplementedError 56 | 57 | 58 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): 59 | """ 60 | Compute the gaussian low pass filter mask. 
61 | 62 | Args: 63 | shape: shape of the filter (volume) 64 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 65 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 66 | """ 67 | T, H, W = shape[-3], shape[-2], shape[-1] 68 | mask = torch.zeros(shape) 69 | if d_s==0 or d_t==0: 70 | return mask 71 | for t in range(T): 72 | for h in range(H): 73 | for w in range(W): 74 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 75 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) 76 | return mask 77 | 78 | 79 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): 80 | """ 81 | Compute the butterworth low pass filter mask. 82 | 83 | Args: 84 | shape: shape of the filter (volume) 85 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian 86 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 87 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 88 | """ 89 | T, H, W = shape[-3], shape[-2], shape[-1] 90 | mask = torch.zeros(shape) 91 | if d_s==0 or d_t==0: 92 | return mask 93 | for t in range(T): 94 | for h in range(H): 95 | for w in range(W): 96 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 97 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) 98 | return mask 99 | 100 | 101 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): 102 | """ 103 | Compute the ideal low pass filter mask. 104 | 105 | Args: 106 | shape: shape of the filter (volume) 107 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 108 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 109 | """ 110 | T, H, W = shape[-3], shape[-2], shape[-1] 111 | mask = torch.zeros(shape) 112 | if d_s==0 or d_t==0: 113 | return mask 114 | for t in range(T): 115 | for h in range(H): 116 | for w in range(W): 117 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) 118 | mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 119 | return mask 120 | 121 | 122 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): 123 | """ 124 | Compute the ideal low pass filter mask (approximated version). 
125 | 126 | Args: 127 | shape: shape of the filter (volume) 128 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0) 129 | d_t: normalized stop frequency for temporal dimension (0.0-1.0) 130 | """ 131 | T, H, W = shape[-3], shape[-2], shape[-1] 132 | mask = torch.zeros(shape) 133 | if d_s==0 or d_t==0: 134 | return mask 135 | 136 | threshold_s = round(int(H // 2) * d_s) 137 | threshold_t = round(T // 2 * d_t) 138 | 139 | cframe, crow, ccol = T // 2, H // 2, W //2 140 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 141 | 142 | return mask -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConsistI2V 2 | 3 | 4 | [**🌐 Homepage**](https://tiger-ai-lab.github.io/ConsistI2V/) | [**📖 arXiv**](https://arxiv.org/abs/2402.04324) | [**🤗 Model**](https://huggingface.co/TIGER-Lab/ConsistI2V) | [**📊 I2V-Bench**](https://drive.google.com/drive/folders/1eg_vtowKZBen74W-A1oeO4bR1K21giks) | [**🤗 Space**](https://huggingface.co/spaces/TIGER-Lab/ConsistI2V) | [**🎬 Replicate Demo**](https://replicate.com/wren93/consisti2v) 5 | 6 | This repo contains the codebase for our TMLR-2024 paper "[ConsistI2V: Enhancing Visual Consistency for Image-to-Video Generation](https://arxiv.org/abs/2402.04324)" 7 | 8 | We propose ConsistI2V, a diffusion-based method to enhance visual consistency for I2V generation. Specifically, we introduce (1) spatiotemporal attention over the first frame to maintain spatial and motion consistency, (2) noise initialization from the low-frequency band of the first frame to enhance layout consistency. These two approaches enable ConsistI2V to generate highly consistent videos. 9 | ConsistI2V 10 | 11 | ## 🔔News 12 | - **[2024-03-26]: Try our Gradio Demo on Huggingface [Space](https://huggingface.co/spaces/TIGER-Lab/ConsistI2V)! Thanks [@AK](https://twitter.com/_akhaliq) for the help.** 13 | - **[2024-03-21]: Add Gradio Demo. Run `python app.py` to launch the demo locally.** 14 | - **[2024-03-09]: Add Replicate [Demo](https://replicate.com/wren93/consisti2v). Thanks [@chenxwh](https://github.com/chenxwh) for the effort!** 15 | - **[2024-02-26]: Release code and [model](https://huggingface.co/TIGER-Lab/ConsistI2V) for ConsistI2V.** 16 | 17 | 18 | ## Environment Setup 19 | Prepare codebase and Conda environment using the following commands: 20 | ``` 21 | git clone https://github.com/TIGER-AI-Lab/ConsistI2V 22 | cd ConsistI2V 23 | 24 | conda env create -f environment.yaml 25 | conda activate consisti2v 26 | ``` 27 | 28 | ## Inference 29 | Our [model](https://huggingface.co/TIGER-Lab/ConsistI2V) is available for download on 🤗 Hugging Face. To generate videos with ConsistI2V, modify the inference configurations in `configs/inference/inference.yaml` and the input prompt file `configs/prompts/default.yaml`, and then run the sampling script with the following command: 30 | ``` 31 | python -m scripts.animate \ 32 | --inference_config configs/inference/inference.yaml \ 33 | --prompt_config configs/prompts/default.yaml \ 34 | --format mp4 35 | ``` 36 | The inference script automatically downloads the model from Hugging Face by specifying `pretrained_model_path` in `configs/inference/inference.yaml` as `TIGER-Lab/ConsistI2V` (default configuration). 
If you are having trouble downloading the model from the script, you can store the model on your local storage and modify `pretrained_model_path` to the local model path. 37 | 38 | You can also explicitly define the input text prompt, negative prompt, sampling seed and first frame path as: 39 | ``` 40 | python -m scripts.animate \ 41 | --inference_config configs/inference/inference.yaml \ 42 | --prompt "timelapse at the snow land with aurora in the sky." \ 43 | --n_prompt "your negative prompt" \ 44 | --seed 42 \ 45 | --path_to_first_frame assets/example/example_01.png \ 46 | --format mp4 47 | ``` 48 | 49 | To modify inference configurations in `configs/inference/inference.yaml` from command line, append extra arguments to the end of the inference command: 50 | ``` 51 | python -m scripts.animate \ 52 | --inference_config configs/inference/inference.yaml \ 53 | ... # additional arguments 54 | --format mp4 55 | sampling_kwargs.num_videos_per_prompt=4 \ # overwrite the configs in the config file 56 | frameinit_kwargs.filter_params.d_s=0.5 57 | ``` 58 | 59 | We also created a Gradio demo for easier use of ConsistI2V. The demo can be launched locally by running the following command: 60 | ``` 61 | conda activate consisti2v 62 | python app.py 63 | ``` 64 | By default, the demo will be running at `localhost:7860`. 65 | 66 | ## Training 67 | Modify the training configurations in `configs/training/training.yaml` and run the following command to train the model: 68 | ``` 69 | python -m torch.distributed.run \ 70 | --nproc_per_node=${GPU_PER_NODE} \ 71 | --master_addr=${MASTER_ADDR} \ 72 | --master_port=${MASTER_PORT} \ 73 | --nnodes=${NUM_NODES} \ 74 | --node_rank=${NODE_RANK} \ 75 | train.py \ 76 | --config configs/training/training.yaml \ 77 | -n consisti2v_training \ 78 | --wandb 79 | ``` 80 | where `GPU_PER_NODE`, `MASTER_ADDR`, `MASTER_PORT`, `NUM_NODES` and `NODE_RANK` can be defined based on your training environment. The dataloader in our code assumes a root folder `train_data.webvid_config.video_folder` containing all videos and a `jsonl` file `train_data.webvid_config.json_path` containing video relative paths and captions, with each line in the following format: 81 | ``` 82 | {"text": "A man rolling a winter sled with a child sitting on it in the snow close-up", "time": "30.030", "file": "relative/path/to/video.mp4", "fps": 29.97002997002997} 83 | ``` 84 | Videos can be stored in multiple subdirectories. Alternatively, you can modify the dataloader to support your own dataset. Similar to model inference, you can also add additional arguments at the end of the training command to modify the training configurations in `configs/training/training.yaml`. 85 | 86 | ## Citation 87 | Please kindly cite our paper if you find our code, data, models or results to be helpful. 88 | ```bibtex 89 | @article{ren2024consisti2v, 90 | title={ConsistI2V: Enhancing Visual Consistency for Image-to-Video Generation}, 91 | author={Ren, Weiming and Yang, Harry and Zhang, Ge and Wei, Cong and Du, Xinrun and Huang, Stephen and Chen, Wenhu}, 92 | journal={arXiv preprint arXiv:2402.04324}, 93 | year={2024} 94 | } 95 | ``` 96 | ## Acknowledgements 97 | Our codebase is built upon [AnimateDiff](https://github.com/guoyww/AnimateDiff), [FreeInit](https://github.com/TianxingWu/FreeInit) and 🤗 [diffusers](https://github.com/huggingface/diffusers). Thanks for open-sourcing. 
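As a reference for the low-frequency noise initialization configured through `frameinit_kwargs`, the snippet below exercises the helpers in `consisti2v/utils/frameinit_utils.py` in isolation. It is a minimal sketch: the latent shape is made up, and in the actual pipeline the clean latent is derived from the encoded first frame (diffused to `frameinit_noise_level`) rather than sampled randomly.
```python
import torch
from consisti2v.utils.frameinit_utils import get_freq_filter, freq_mix_3d

# Illustrative video latent of shape (B, C, T, H, W).
latent = torch.randn(1, 4, 16, 32, 32)
noise = torch.randn_like(latent)

# Gaussian low-pass filter with the d_s / d_t values from configs/inference/inference.yaml.
lpf = get_freq_filter(latent.shape, latent.device, filter_type="gaussian", n=4, d_s=0.25, d_t=0.25)

# Keep the low-frequency band of `latent` and the high-frequency band of `noise`.
mixed = freq_mix_3d(latent, noise, lpf)
print(mixed.shape)  # torch.Size([1, 4, 16, 32, 32])
```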
98 | -------------------------------------------------------------------------------- /consisti2v/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | import wandb 10 | 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | from torchmetrics.image.fid import _compute_fid 15 | 16 | 17 | def zero_rank_print(s): 18 | if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 19 | 20 | 21 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, wandb=False, global_step=0, format="gif"): 22 | videos = rearrange(videos, "b c t h w -> t b c h w") 23 | outputs = [] 24 | for x in videos: 25 | x = torchvision.utils.make_grid(x, nrow=n_rows) 26 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 27 | if rescale: 28 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 29 | x = (x * 255).numpy().astype(np.uint8) 30 | outputs.append(x) 31 | 32 | if wandb: 33 | wandb_video = wandb.Video(outputs, fps=fps) 34 | wandb.log({"val_videos": wandb_video}, step=global_step) 35 | 36 | os.makedirs(os.path.dirname(path), exist_ok=True) 37 | if format == "gif": 38 | imageio.mimsave(path, outputs, fps=fps) 39 | elif format == "mp4": 40 | torchvision.io.write_video(path, np.array(outputs), fps=fps, video_codec='h264', options={'crf': '10'}) 41 | 42 | # DDIM Inversion 43 | @torch.no_grad() 44 | def init_prompt(prompt, pipeline): 45 | uncond_input = pipeline.tokenizer( 46 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 47 | return_tensors="pt" 48 | ) 49 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 50 | text_input = pipeline.tokenizer( 51 | [prompt], 52 | padding="max_length", 53 | max_length=pipeline.tokenizer.model_max_length, 54 | truncation=True, 55 | return_tensors="pt", 56 | ) 57 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 58 | context = torch.cat([uncond_embeddings, text_embeddings]) 59 | 60 | return context 61 | 62 | 63 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 64 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 65 | timestep, next_timestep = min( 66 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 67 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 68 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 69 | beta_prod_t = 1 - alpha_prod_t 70 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 71 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 72 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 73 | return next_sample 74 | 75 | 76 | def get_noise_pred_single(latents, t, context, first_frame_latents, frame_stride, unet): 77 | noise_pred = unet(latents, t, encoder_hidden_states=context, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample 78 | return noise_pred 79 | 80 | 81 | @torch.no_grad() 82 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, first_frame_latents, frame_stride): 83 | context = init_prompt(prompt, pipeline) 84 | uncond_embeddings, cond_embeddings = context.chunk(2) 
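    # DDIM inversion: starting from the (almost) clean latent, walk the scheduler's
    # timesteps in reverse order (small t to large t); next_step() applies the DDIM
    # update backwards, so each appended latent is progressively noisier.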
85 | all_latent = [latent] 86 | latent = latent.clone().detach() 87 | for i in tqdm(range(num_inv_steps)): 88 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 89 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, first_frame_latents, frame_stride, pipeline.unet) 90 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 91 | all_latent.append(latent) 92 | return all_latent 93 | 94 | 95 | @torch.no_grad() 96 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", first_frame_latents=None, frame_stride=3): 97 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, first_frame_latents, frame_stride) 98 | return ddim_latents 99 | 100 | 101 | def compute_fid(real_features, fake_features, num_features, device): 102 | orig_dtype = real_features.dtype 103 | 104 | mx_num_feats = (num_features, num_features) 105 | real_features_sum = torch.zeros(num_features).double().to(device) 106 | real_features_cov_sum = torch.zeros(mx_num_feats).double().to(device) 107 | real_features_num_samples = torch.tensor(0).long().to(device) 108 | 109 | fake_features_sum = torch.zeros(num_features).double().to(device) 110 | fake_features_cov_sum = torch.zeros(mx_num_feats).double().to(device) 111 | fake_features_num_samples = torch.tensor(0).long().to(device) 112 | 113 | real_features = real_features.double() 114 | fake_features = fake_features.double() 115 | 116 | real_features_sum += real_features.sum(dim=0) 117 | real_features_cov_sum += real_features.t().mm(real_features) 118 | real_features_num_samples += real_features.shape[0] 119 | 120 | fake_features_sum += fake_features.sum(dim=0) 121 | fake_features_cov_sum += fake_features.t().mm(fake_features) 122 | fake_features_num_samples += fake_features.shape[0] 123 | 124 | """Calculate FID score based on accumulated extracted features from the two distributions.""" 125 | if real_features_num_samples < 2 or fake_features_num_samples < 2: 126 | raise RuntimeError("More than one sample is required for both the real and fake distributed to compute FID") 127 | mean_real = (real_features_sum / real_features_num_samples).unsqueeze(0) 128 | mean_fake = (fake_features_sum / fake_features_num_samples).unsqueeze(0) 129 | 130 | cov_real_num = real_features_cov_sum - real_features_num_samples * mean_real.t().mm(mean_real) 131 | cov_real = cov_real_num / (real_features_num_samples - 1) 132 | cov_fake_num = fake_features_cov_sum - fake_features_num_samples * mean_fake.t().mm(mean_fake) 133 | cov_fake = cov_fake_num / (fake_features_num_samples - 1) 134 | return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(orig_dtype) 135 | 136 | 137 | def compute_inception_score(gen_probs, num_splits=10): 138 | num_gen = gen_probs.shape[0] 139 | gen_probs = gen_probs.detach().cpu().numpy() 140 | scores = [] 141 | np.random.RandomState(42).shuffle(gen_probs) 142 | for i in range(num_splits): 143 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 144 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 145 | kl = np.mean(np.sum(kl, axis=1)) 146 | scores.append(np.exp(kl)) 147 | return float(np.mean(scores)), float(np.std(scores)) 148 | # idx = torch.randperm(features.shape[0]) 149 | # features = features[idx] 150 | # # calculate probs and logits 151 | # prob = features.softmax(dim=1) 152 | # log_prob = features.log_softmax(dim=1) 153 | 154 | # # split into groups 155 | # prob = prob.chunk(splits, dim=0) 156 | # 
log_prob = log_prob.chunk(splits, dim=0) 157 | 158 | # # calculate score per split 159 | # mean_prob = [p.mean(dim=0, keepdim=True) for p in prob] 160 | # kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)] 161 | # kl_ = [k.sum(dim=1).mean().exp() for k in kl_] 162 | # kl = torch.stack(kl_) 163 | 164 | # return mean and std 165 | # return kl.mean(), kl.std() -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | 4 | import os 5 | import time 6 | import subprocess 7 | from omegaconf import OmegaConf 8 | import torch 9 | from cog import BasePredictor, Input, Path 10 | from diffusers import AutoencoderKL, DDIMScheduler 11 | from transformers import CLIPTextModel, CLIPTokenizer 12 | 13 | from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel 14 | from consisti2v.pipelines.pipeline_conditional_animation import ( 15 | ConditionalAnimationPipeline, 16 | ) 17 | from consisti2v.utils.util import save_videos_grid 18 | 19 | 20 | URL = { 21 | k: f"https://weights.replicate.delivery/default/ConsistI2V_cache/{k}.tar" 22 | for k in ["text_encoder", "vae", "tokenizer", "unet"] 23 | } 24 | MODEL_CACHE = { 25 | k: f"model_cache/{k}" for k in ["text_encoder", "vae", "tokenizer", "unet"] 26 | } 27 | 28 | 29 | def download_weights(url, dest): 30 | start = time.time() 31 | print("downloading url: ", url) 32 | print("downloading to: ", dest) 33 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 34 | print("downloading took: ", time.time() - start) 35 | 36 | 37 | class Predictor(BasePredictor): 38 | def setup(self) -> None: 39 | """Load the model into memory to make running multiple predictions efficient""" 40 | inference_config = "configs/inference/inference.yaml" 41 | self.config = OmegaConf.load(inference_config) 42 | noise_scheduler = DDIMScheduler( 43 | **OmegaConf.to_container(self.config.noise_scheduler_kwargs) 44 | ) 45 | 46 | # The weights are pushed to replicate.delivery, see def save_weights() below for details 47 | for k in ["text_encoder", "vae", "tokenizer", "unet"]: 48 | if not os.path.exists(MODEL_CACHE[k]): 49 | download_weights(URL[k], MODEL_CACHE[k]) 50 | 51 | tokenizer = CLIPTokenizer.from_pretrained( 52 | MODEL_CACHE["tokenizer"], use_safetensors=True 53 | ) 54 | text_encoder = CLIPTextModel.from_pretrained(MODEL_CACHE["text_encoder"]) 55 | vae = AutoencoderKL.from_pretrained(MODEL_CACHE["vae"], use_safetensors=True) 56 | unet = VideoLDMUNet3DConditionModel.from_pretrained( 57 | MODEL_CACHE["unet"], 58 | subfolder="unet", 59 | variant=self.config.unet_additional_kwargs["variant"], 60 | temp_pos_embedding=self.config.unet_additional_kwargs["temp_pos_embedding"], 61 | augment_temporal_attention=self.config.unet_additional_kwargs[ 62 | "augment_temporal_attention" 63 | ], 64 | use_temporal=True, 65 | n_frames=self.config.sampling_kwargs["n_frames"], 66 | n_temp_heads=self.config.unet_additional_kwargs["n_temp_heads"], 67 | first_frame_condition_mode=self.config.unet_additional_kwargs[ 68 | "first_frame_condition_mode" 69 | ], 70 | use_frame_stride_condition=self.config.unet_additional_kwargs[ 71 | "use_frame_stride_condition" 72 | ], 73 | use_safetensors=True, 74 | ) 75 | 76 | self.pipeline = ConditionalAnimationPipeline( 77 | vae=vae, 78 | text_encoder=text_encoder, 79 | tokenizer=tokenizer, 80 | unet=unet, 
81 | scheduler=noise_scheduler, 82 | ).to("cuda") 83 | 84 | def predict( 85 | self, 86 | image: Path = Input(description="Input image as the first frame of the video."), 87 | prompt: str = Input( 88 | description="Input prompt", 89 | default="An astronaut riding a rainbow unicorn", 90 | ), 91 | negative_prompt: str = Input( 92 | description="Input Negative Prompt", 93 | default="", 94 | ), 95 | num_inference_steps: int = Input( 96 | description="Number of denoising steps", ge=1, le=500, default=50 97 | ), 98 | text_guidance_scale: float = Input( 99 | description="Scale for classifier-free guidance from the text", 100 | ge=1, 101 | le=50, 102 | default=7.5, 103 | ), 104 | image_guidance_scale: float = Input( 105 | description="Scale for classifier-free guidance from the image", default=1.0 106 | ), 107 | seed: int = Input( 108 | description="Random seed. Leave blank to randomize the seed", default=None 109 | ), 110 | ) -> Path: 111 | """Run a single prediction on the model""" 112 | if seed is None: 113 | seed = int.from_bytes(os.urandom(2), "big") 114 | print(f"Using seed: {seed}") 115 | torch.manual_seed(seed) 116 | 117 | if self.config.frameinit_kwargs.enable: 118 | self.pipeline.init_filter( 119 | width=self.config.sampling_kwargs.width, 120 | height=self.config.sampling_kwargs.height, 121 | video_length=self.config.sampling_kwargs.n_frames, 122 | filter_params=self.config.frameinit_kwargs.filter_params, 123 | ) 124 | 125 | sample = self.pipeline( 126 | prompt, 127 | negative_prompt=negative_prompt, 128 | first_frame_paths=str(image), 129 | num_inference_steps=num_inference_steps, 130 | guidance_scale_txt=text_guidance_scale, 131 | guidance_scale_img=image_guidance_scale, 132 | width=self.config.sampling_kwargs.width, # output video only supports 16 frames of 256x256 133 | height=self.config.sampling_kwargs.height, 134 | video_length=self.config.sampling_kwargs.n_frames, 135 | noise_sampling_method=self.config.unet_additional_kwargs[ 136 | "noise_sampling_method" 137 | ], 138 | noise_alpha=float(self.config.unet_additional_kwargs["noise_alpha"]), 139 | eta=self.config.sampling_kwargs.ddim_eta, 140 | frame_stride=self.config.sampling_kwargs.frame_stride, 141 | guidance_rescale=self.config.sampling_kwargs.guidance_rescale, 142 | num_videos_per_prompt=self.config.sampling_kwargs.num_videos_per_prompt, 143 | use_frameinit=self.config.frameinit_kwargs.enable, 144 | frameinit_noise_level=self.config.frameinit_kwargs.noise_level, 145 | camera_motion=self.config.frameinit_kwargs.camera_motion, 146 | ).videos 147 | out_path = "/tmp/out.mp4" 148 | save_videos_grid(sample, out_path, format="mp4") 149 | return Path(out_path) 150 | 151 | 152 | def save_weights(): 153 | "Load the weights, saved to local and push to replicate.delivery" 154 | inference_config = "configs/inference/inference.yaml" 155 | config = OmegaConf.load(inference_config) 156 | 157 | tokenizer = CLIPTokenizer.from_pretrained( 158 | config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True 159 | ) 160 | tokenizer.save_pretrained("ConsistI2V_cache/tokenizer") 161 | text_encoder = CLIPTextModel.from_pretrained( 162 | config.pretrained_model_path, subfolder="text_encoder" 163 | ) 164 | text_encoder.save_pretrained( 165 | "ConsistI2V_cache/text_encoder", safe_serialization=True 166 | ) 167 | vae = AutoencoderKL.from_pretrained( 168 | config.pretrained_model_path, subfolder="vae", use_safetensors=True 169 | ) 170 | vae.save_pretrained("ConsistI2V_cache/vae", safe_serialization=True) 171 | unet = 
VideoLDMUNet3DConditionModel.from_pretrained( 172 | config.pretrained_model_path, 173 | subfolder="unet", 174 | variant=config.unet_additional_kwargs["variant"], 175 | temp_pos_embedding=config.unet_additional_kwargs["temp_pos_embedding"], 176 | augment_temporal_attention=config.unet_additional_kwargs[ 177 | "augment_temporal_attention" 178 | ], 179 | use_temporal=True, 180 | n_frames=config.sampling_kwargs["n_frames"], 181 | n_temp_heads=config.unet_additional_kwargs["n_temp_heads"], 182 | first_frame_condition_mode=config.unet_additional_kwargs[ 183 | "first_frame_condition_mode" 184 | ], 185 | use_frame_stride_condition=config.unet_additional_kwargs[ 186 | "use_frame_stride_condition" 187 | ], 188 | use_safetensors=True, 189 | ) 190 | unet.save_pretrained("ConsistI2V_cache/unet", safe_serialization=True) 191 | -------------------------------------------------------------------------------- /scripts/animate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import random 4 | import os 5 | import logging 6 | from omegaconf import OmegaConf 7 | 8 | import torch 9 | 10 | import diffusers 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | 13 | from transformers import CLIPTextModel, CLIPTokenizer 14 | 15 | from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel 16 | from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline 17 | from consisti2v.utils.util import save_videos_grid 18 | from diffusers.utils.import_utils import is_xformers_available 19 | 20 | def main(args, config): 21 | logging.basicConfig( 22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 23 | datefmt="%m/%d/%Y %H:%M:%S", 24 | level=logging.INFO, 25 | ) 26 | diffusers.utils.logging.set_verbosity_info() 27 | 28 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 29 | savedir = f"{config.output_dir}/{config.output_name}-{time_str}" 30 | os.makedirs(savedir) 31 | 32 | samples = [] 33 | sample_idx = 0 34 | 35 | ### >>> create validation pipeline >>> ### 36 | if config.pipeline_pretrained_path is None: 37 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs)) 38 | tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True) 39 | text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") 40 | vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True) 41 | unet = VideoLDMUNet3DConditionModel.from_pretrained( 42 | config.pretrained_model_path, 43 | subfolder="unet", 44 | variant=config.unet_additional_kwargs['variant'], 45 | temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'], 46 | augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'], 47 | use_temporal=True, 48 | n_frames=config.sampling_kwargs['n_frames'], 49 | n_temp_heads=config.unet_additional_kwargs['n_temp_heads'], 50 | first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'], 51 | use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'], 52 | use_safetensors=True 53 | ) 54 | 55 | # 1. 
unet ckpt 56 | if config.unet_path is not None: 57 | if os.path.isdir(config.unet_path): 58 | unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path) 59 | m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False) 60 | assert len(u) == 0 61 | del unet_dict 62 | else: 63 | checkpoint_dict = torch.load(config.unet_path, map_location="cpu") 64 | state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict 65 | if config.unet_ckpt_prefix is not None: 66 | state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()} 67 | m, u = unet.load_state_dict(state_dict, strict=False) 68 | assert len(u) == 0 69 | 70 | if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: 71 | unet.enable_xformers_memory_efficient_attention() 72 | 73 | pipeline = ConditionalAnimationPipeline( 74 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler) 75 | 76 | else: 77 | pipeline = ConditionalAnimationPipeline.from_pretrained(config.pipeline_pretrained_path) 78 | 79 | pipeline.to("cuda") 80 | 81 | # (frameinit) initialize frequency filter for noise reinitialization ------------- 82 | if config.frameinit_kwargs.enable: 83 | pipeline.init_filter( 84 | width = config.sampling_kwargs.width, 85 | height = config.sampling_kwargs.height, 86 | video_length = config.sampling_kwargs.n_frames, 87 | filter_params = config.frameinit_kwargs.filter_params, 88 | ) 89 | # ------------------------------------------------------------------------------- 90 | ### <<< create validation pipeline <<< ### 91 | 92 | if args.prompt is not None: 93 | prompts = [args.prompt] 94 | n_prompts = [args.n_prompt] 95 | first_frame_paths = [args.path_to_first_frame] 96 | random_seeds = [int(args.seed)] if args.seed != "random" else "random" 97 | else: 98 | prompt_config = OmegaConf.load(args.prompt_config) 99 | prompts = prompt_config.prompts 100 | n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts 101 | first_frame_paths = prompt_config.path_to_first_frames 102 | random_seeds = prompt_config.seeds 103 | 104 | if random_seeds == "random": 105 | random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))] 106 | else: 107 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 108 | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds 109 | 110 | config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths}) 111 | for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(zip(prompts, n_prompts, first_frame_paths, random_seeds)): 112 | # manually set random seed for reproduction 113 | if random_seed != -1: torch.manual_seed(random_seed) 114 | else: torch.seed() 115 | config.prompt_kwargs.random_seeds.append(torch.initial_seed()) 116 | 117 | print(f"current seed: {torch.initial_seed()}") 118 | print(f"sampling {prompt} ...") 119 | sample = pipeline( 120 | prompt, 121 | negative_prompt = n_prompt, 122 | first_frame_paths = first_frame_path, 123 | num_inference_steps = config.sampling_kwargs.steps, 124 | guidance_scale_txt = config.sampling_kwargs.guidance_scale_txt, 125 | guidance_scale_img = config.sampling_kwargs.guidance_scale_img, 126 | width = config.sampling_kwargs.width, 127 | height = config.sampling_kwargs.height, 128 | video_length = 
config.sampling_kwargs.n_frames, 129 | noise_sampling_method = config.unet_additional_kwargs['noise_sampling_method'], 130 | noise_alpha = float(config.unet_additional_kwargs['noise_alpha']), 131 | eta = config.sampling_kwargs.ddim_eta, 132 | frame_stride = config.sampling_kwargs.frame_stride, 133 | guidance_rescale = config.sampling_kwargs.guidance_rescale, 134 | num_videos_per_prompt = config.sampling_kwargs.num_videos_per_prompt, 135 | use_frameinit = config.frameinit_kwargs.enable, 136 | frameinit_noise_level = config.frameinit_kwargs.noise_level, 137 | camera_motion = config.frameinit_kwargs.camera_motion, 138 | ).videos 139 | samples.append(sample) 140 | 141 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "") 142 | if sample.shape[0] > 1: 143 | for cnt, samp in enumerate(sample): 144 | save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format) 145 | else: 146 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format) 147 | print(f"save to {savedir}/sample/{prompt}.{args.format}") 148 | 149 | sample_idx += 1 150 | 151 | samples = torch.concat(samples) 152 | save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format) 153 | 154 | OmegaConf.save(config, f"{savedir}/config.yaml") 155 | 156 | if args.save_model: 157 | pipeline.save_pretrained(f"{savedir}/model") 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml") 163 | parser.add_argument("--prompt", "-p", type=str, default=None) 164 | parser.add_argument("--n_prompt", "-n", type=str, default="") 165 | parser.add_argument("--seed", type=str, default="random") 166 | parser.add_argument("--path_to_first_frame", "-f", type=str, default=None) 167 | parser.add_argument("--prompt_config", type=str, default="configs/prompts/default.yaml") 168 | parser.add_argument("--format", type=str, default="mp4", choices=["gif", "mp4"]) 169 | parser.add_argument("--save_model", action="store_true") 170 | parser.add_argument("optional_args", nargs='*', default=[]) 171 | args = parser.parse_args() 172 | 173 | config = OmegaConf.load(args.inference_config) 174 | 175 | if args.optional_args: 176 | modified_config = OmegaConf.from_dotlist(args.optional_args) 177 | config = OmegaConf.merge(config, modified_config) 178 | 179 | main(args, config) 180 | -------------------------------------------------------------------------------- /scripts/animate_autoregress.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import random 4 | import os 5 | import logging 6 | from omegaconf import OmegaConf 7 | 8 | import torch 9 | 10 | import diffusers 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | 13 | from transformers import CLIPTextModel, CLIPTokenizer 14 | 15 | from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel 16 | from consisti2v.pipelines.pipeline_autoregress_animation import AutoregressiveAnimationPipeline 17 | from consisti2v.utils.util import save_videos_grid 18 | from diffusers.utils.import_utils import is_xformers_available 19 | 20 | def main(args, config): 21 | logging.basicConfig( 22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 23 | datefmt="%m/%d/%Y %H:%M:%S", 24 | level=logging.INFO, 25 | ) 26 | 
diffusers.utils.logging.set_verbosity_info() 27 | 28 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 29 | savedir = f"{config.output_dir}/{config.output_name}-{time_str}" 30 | os.makedirs(savedir) 31 | 32 | samples = [] 33 | sample_idx = 0 34 | 35 | ### >>> create validation pipeline >>> ### 36 | if config.pipeline_pretrained_path is None: 37 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs)) 38 | tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True) 39 | text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") 40 | vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True) 41 | unet = VideoLDMUNet3DConditionModel.from_pretrained( 42 | config.pretrained_model_path, 43 | subfolder="unet", 44 | variant=config.unet_additional_kwargs['variant'], 45 | temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'], 46 | augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'], 47 | use_temporal=True, 48 | n_frames=config.sampling_kwargs['n_frames'], 49 | n_temp_heads=config.unet_additional_kwargs['n_temp_heads'], 50 | first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'], 51 | use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'], 52 | use_safetensors=True 53 | ) 54 | 55 | params_unet = [p.numel() for n, p in unet.named_parameters()] 56 | params_vae = [p.numel() for n, p in vae.named_parameters()] 57 | params_text_encoder = [p.numel() for n, p in text_encoder.named_parameters()] 58 | params = params_unet + params_vae + params_text_encoder 59 | print(f"### UNet Parameters: {sum(params) / 1e6} M") 60 | 61 | # 1. 
unet ckpt 62 | if config.unet_path is not None: 63 | if os.path.isdir(config.unet_path): 64 | unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path) 65 | m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False) 66 | assert len(u) == 0 67 | del unet_dict 68 | else: 69 | checkpoint_dict = torch.load(config.unet_path, map_location="cpu") 70 | state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict 71 | if config.unet_ckpt_prefix is not None: 72 | state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()} 73 | m, u = unet.load_state_dict(state_dict, strict=False) 74 | assert len(u) == 0 75 | 76 | if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: 77 | unet.enable_xformers_memory_efficient_attention() 78 | 79 | pipeline = AutoregressiveAnimationPipeline( 80 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler) 81 | 82 | else: 83 | pipeline = AutoregressiveAnimationPipeline.from_pretrained(config.pipeline_pretrained_path) 84 | 85 | pipeline.to("cuda") 86 | 87 | # (frameinit) initialize frequency filter for noise reinitialization ------------- 88 | if config.frameinit_kwargs.enable: 89 | pipeline.init_filter( 90 | width = config.sampling_kwargs.width, 91 | height = config.sampling_kwargs.height, 92 | video_length = config.sampling_kwargs.n_frames, 93 | filter_params = config.frameinit_kwargs.filter_params, 94 | ) 95 | # ------------------------------------------------------------------------------- 96 | ### <<< create validation pipeline <<< ### 97 | 98 | if args.prompt is not None: 99 | prompts = [args.prompt] 100 | n_prompts = [args.n_prompt] 101 | first_frame_paths = [args.path_to_first_frame] 102 | random_seeds = [int(args.seed)] if args.seed != "random" else "random" 103 | else: 104 | prompt_config = OmegaConf.load(args.prompt_config) 105 | prompts = prompt_config.prompts 106 | n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts 107 | first_frame_paths = prompt_config.path_to_first_frames 108 | random_seeds = prompt_config.seeds 109 | 110 | if random_seeds == "random": 111 | random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))] 112 | else: 113 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 114 | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds 115 | 116 | config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths}) 117 | for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(zip(prompts, n_prompts, first_frame_paths, random_seeds)): 118 | # manually set random seed for reproduction 119 | if random_seed != -1: torch.manual_seed(random_seed) 120 | else: torch.seed() 121 | config.prompt_kwargs.random_seeds.append(torch.initial_seed()) 122 | 123 | print(f"current seed: {torch.initial_seed()}") 124 | print(f"sampling {prompt} ...") 125 | sample = pipeline( 126 | prompt, 127 | negative_prompt = n_prompt, 128 | first_frame_paths = first_frame_path, 129 | num_inference_steps = config.sampling_kwargs.steps, 130 | guidance_scale_txt = config.sampling_kwargs.guidance_scale_txt, 131 | guidance_scale_img = config.sampling_kwargs.guidance_scale_img, 132 | width = config.sampling_kwargs.width, 133 | height = config.sampling_kwargs.height, 134 | video_length = 
config.sampling_kwargs.n_frames, 135 | noise_sampling_method = config.unet_additional_kwargs['noise_sampling_method'], 136 | noise_alpha = float(config.unet_additional_kwargs['noise_alpha']), 137 | eta = config.sampling_kwargs.ddim_eta, 138 | frame_stride = config.sampling_kwargs.frame_stride, 139 | guidance_rescale = config.sampling_kwargs.guidance_rescale, 140 | num_videos_per_prompt = config.sampling_kwargs.num_videos_per_prompt, 141 | autoregress_steps = config.sampling_kwargs.autoregress_steps, 142 | use_frameinit = config.frameinit_kwargs.enable, 143 | frameinit_noise_level = config.frameinit_kwargs.noise_level, 144 | ).videos 145 | samples.append(sample) 146 | 147 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "") 148 | if sample.shape[0] > 1: 149 | for cnt, samp in enumerate(sample): 150 | save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format) 151 | else: 152 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format) 153 | print(f"save to {savedir}/sample/{prompt}.{args.format}") 154 | 155 | sample_idx += 1 156 | 157 | samples = torch.concat(samples) 158 | save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format) 159 | 160 | OmegaConf.save(config, f"{savedir}/config.yaml") 161 | 162 | if args.save_model: 163 | pipeline.save_pretrained(f"{savedir}/model") 164 | 165 | 166 | if __name__ == "__main__": 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--inference_config", type=str, default="configs/inference/inference_autoregress.yaml") 169 | parser.add_argument("--prompt", "-p", type=str, default=None) 170 | parser.add_argument("--n_prompt", "-n", type=str, default="") 171 | parser.add_argument("--seed", type=str, default="random") 172 | parser.add_argument("--path_to_first_frame", "-f", type=str, default=None) 173 | parser.add_argument("--prompt_config", type=str, default="configs/prompts/default.yaml") 174 | parser.add_argument("--format", type=str, default="gif", choices=["gif", "mp4"]) 175 | parser.add_argument("--save_model", action="store_true") 176 | parser.add_argument("optional_args", nargs='*', default=[]) 177 | args = parser.parse_args() 178 | 179 | config = OmegaConf.load(args.inference_config) 180 | 181 | if args.optional_args: 182 | modified_config = OmegaConf.from_dotlist(args.optional_args) 183 | config = OmegaConf.merge(config, modified_config) 184 | 185 | main(args, config) 186 | -------------------------------------------------------------------------------- /consisti2v/models/rotary_embedding.py: -------------------------------------------------------------------------------- 1 | from math import pi, log 2 | 3 | import torch 4 | from torch.nn import Module, ModuleList 5 | from torch.cuda.amp import autocast 6 | from torch import nn, einsum, broadcast_tensors, Tensor 7 | 8 | from einops import rearrange, repeat 9 | 10 | from beartype import beartype 11 | from beartype.typing import Literal, Union, Optional 12 | 13 | # helper functions 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def default(val, d): 19 | return val if exists(val) else d 20 | 21 | # broadcat, as tortoise-tts was using it 22 | 23 | def broadcat(tensors, dim = -1): 24 | broadcasted_tensors = broadcast_tensors(*tensors) 25 | return torch.cat(broadcasted_tensors, dim = dim) 26 | 27 | # rotary embedding helper functions 28 | 29 | def rotate_half(x): 30 | x = rearrange(x, '... (d r) -> ... 
d r', r = 2) 31 | x1, x2 = x.unbind(dim = -1) 32 | x = torch.stack((-x2, x1), dim = -1) 33 | return rearrange(x, '... d r -> ... (d r)') 34 | 35 | @autocast(enabled = False) 36 | def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): 37 | if t.ndim == 3: 38 | seq_len = t.shape[seq_dim] 39 | freqs = freqs[-seq_len:].to(t) 40 | 41 | rot_dim = freqs.shape[-1] 42 | end_index = start_index + rot_dim 43 | 44 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 45 | 46 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 47 | t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) 48 | return torch.cat((t_left, t, t_right), dim = -1) 49 | 50 | # learned rotation helpers 51 | 52 | def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): 53 | if exists(freq_ranges): 54 | rotations = einsum('..., f -> ... f', rotations, freq_ranges) 55 | rotations = rearrange(rotations, '... r f -> ... (r f)') 56 | 57 | rotations = repeat(rotations, '... n -> ... (n r)', r = 2) 58 | return apply_rotary_emb(rotations, t, start_index = start_index) 59 | 60 | # classes 61 | 62 | class RotaryEmbedding(Module): 63 | @beartype 64 | def __init__( 65 | self, 66 | dim, 67 | custom_freqs: Optional[Tensor] = None, 68 | freqs_for: Union[ 69 | Literal['lang'], 70 | Literal['pixel'], 71 | Literal['constant'] 72 | ] = 'lang', 73 | theta = 10000, 74 | max_freq = 10, 75 | num_freqs = 1, 76 | learned_freq = False, 77 | use_xpos = False, 78 | xpos_scale_base = 512, 79 | interpolate_factor = 1., 80 | theta_rescale_factor = 1., 81 | seq_before_head_dim = False 82 | ): 83 | super().__init__() 84 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 85 | # has some connection to NTK literature 86 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 87 | 88 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 89 | 90 | self.freqs_for = freqs_for 91 | 92 | if exists(custom_freqs): 93 | freqs = custom_freqs 94 | elif freqs_for == 'lang': 95 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 96 | elif freqs_for == 'pixel': 97 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 98 | elif freqs_for == 'constant': 99 | freqs = torch.ones(num_freqs).float() 100 | 101 | self.tmp_store('cached_freqs', None) 102 | self.tmp_store('cached_scales', None) 103 | 104 | self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) 105 | 106 | self.learned_freq = learned_freq 107 | 108 | # dummy for device 109 | 110 | self.tmp_store('dummy', torch.tensor(0)) 111 | 112 | # default sequence dimension 113 | 114 | self.seq_before_head_dim = seq_before_head_dim 115 | self.default_seq_dim = -3 if seq_before_head_dim else -2 116 | 117 | # interpolation factors 118 | 119 | assert interpolate_factor >= 1. 
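        # positions are divided by interpolate_factor in get_seq_pos(), so a factor > 1
        # linearly interpolates rotary positions (useful when running on sequences longer
        # than those seen during training)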
120 | self.interpolate_factor = interpolate_factor 121 | 122 | # xpos 123 | 124 | self.use_xpos = use_xpos 125 | if not use_xpos: 126 | self.tmp_store('scale', None) 127 | return 128 | 129 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 130 | self.scale_base = xpos_scale_base 131 | self.tmp_store('scale', scale) 132 | 133 | @property 134 | def device(self): 135 | return self.dummy.device 136 | 137 | def tmp_store(self, key, value): 138 | self.register_buffer(key, value, persistent = False) 139 | 140 | def get_seq_pos(self, seq_len, device, dtype, offset = 0): 141 | return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor 142 | 143 | def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None): 144 | seq_dim = default(seq_dim, self.default_seq_dim) 145 | 146 | assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' 147 | 148 | device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] 149 | 150 | if exists(freq_seq_len): 151 | assert freq_seq_len >= seq_len 152 | seq_len = freq_seq_len 153 | 154 | if seq_pos is None: 155 | seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) 156 | else: 157 | assert seq_pos.shape[0] == seq_len 158 | 159 | freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset) 160 | 161 | if seq_dim == -3: 162 | freqs = rearrange(freqs, 'n d -> n 1 d') 163 | 164 | return apply_rotary_emb(freqs, t, seq_dim = seq_dim) 165 | 166 | def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): 167 | seq_dim = default(seq_dim, self.default_seq_dim) 168 | 169 | q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] 170 | assert q_len <= k_len 171 | rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len) 172 | rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim) 173 | 174 | rotated_q = rotated_q.type(q.dtype) 175 | rotated_k = rotated_k.type(k.dtype) 176 | 177 | return rotated_q, rotated_k 178 | 179 | def rotate_queries_and_keys(self, q, k, seq_dim = None): 180 | seq_dim = default(seq_dim, self.default_seq_dim) 181 | 182 | assert self.use_xpos 183 | device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] 184 | 185 | seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) 186 | 187 | freqs = self.forward(seq, seq_len = seq_len) 188 | scale = self.get_scale(seq, seq_len = seq_len).to(dtype) 189 | 190 | if seq_dim == -3: 191 | freqs = rearrange(freqs, 'n d -> n 1 d') 192 | scale = rearrange(scale, 'n d -> n 1 d') 193 | 194 | rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) 195 | rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) 196 | 197 | rotated_q = rotated_q.type(q.dtype) 198 | rotated_k = rotated_k.type(k.dtype) 199 | 200 | return rotated_q, rotated_k 201 | 202 | @beartype 203 | def get_scale( 204 | self, 205 | t: Tensor, 206 | seq_len: Optional[int] = None, 207 | offset = 0 208 | ): 209 | assert self.use_xpos 210 | 211 | should_cache = exists(seq_len) 212 | 213 | if ( 214 | should_cache and \ 215 | exists(self.cached_scales) and \ 216 | (seq_len + offset) <= self.cached_scales.shape[0] 217 | ): 218 | return self.cached_scales[offset:(offset + seq_len)] 219 | 220 | scale = 1. 
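# `self.use_xpos` is asserted at the top of this method, so the branch below always runs: the
# per-position power is centred on the sequence midpoint and exponentiates the stored `scale`
# buffer. Queries are later multiplied by this scale and keys by its reciprocal (see
# `rotate_queries_and_keys`), which is the xPos mechanism intended to improve length extrapolation.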
221 | if self.use_xpos: 222 | power = (t - len(t) // 2) / self.scale_base 223 | scale = self.scale ** rearrange(power, 'n -> n 1') 224 | scale = torch.cat((scale, scale), dim = -1) 225 | 226 | if should_cache: 227 | self.tmp_store('cached_scales', scale) 228 | 229 | return scale 230 | 231 | def get_axial_freqs(self, *dims): 232 | Colon = slice(None) 233 | all_freqs = [] 234 | 235 | for ind, dim in enumerate(dims): 236 | if self.freqs_for == 'pixel': 237 | pos = torch.linspace(-1, 1, steps = dim, device = self.device) 238 | else: 239 | pos = torch.arange(dim, device = self.device) 240 | 241 | freqs = self.forward(pos, seq_len = dim) 242 | 243 | all_axis = [None] * len(dims) 244 | all_axis[ind] = Colon 245 | 246 | new_axis_slice = (Ellipsis, *all_axis, Colon) 247 | all_freqs.append(freqs[new_axis_slice]) 248 | 249 | all_freqs = broadcast_tensors(*all_freqs) 250 | return torch.cat(all_freqs, dim = -1) 251 | 252 | @autocast(enabled = False) 253 | def forward( 254 | self, 255 | t: Tensor, 256 | seq_len = None, 257 | offset = 0 258 | ): 259 | # should_cache = ( 260 | # not self.learned_freq and \ 261 | # exists(seq_len) and \ 262 | # self.freqs_for != 'pixel' 263 | # ) 264 | 265 | # if ( 266 | # should_cache and \ 267 | # exists(self.cached_freqs) and \ 268 | # (offset + seq_len) <= self.cached_freqs.shape[0] 269 | # ): 270 | # return self.cached_freqs[offset:(offset + seq_len)].detach() 271 | 272 | freqs = self.freqs 273 | 274 | freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) 275 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2) 276 | 277 | # if should_cache: 278 | # self.tmp_store('cached_freqs', freqs.detach()) 279 | 280 | return freqs 281 | -------------------------------------------------------------------------------- /consisti2v/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random 2 | import json 3 | import numpy as np 4 | from einops import rearrange 5 | from decord import VideoReader 6 | 7 | import torch 8 | import torchvision.transforms as transforms 9 | from torch.utils.data.dataset import Dataset 10 | 11 | from diffusers.utils import logging 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | class WebVid10M(Dataset): 16 | def __init__( 17 | self, 18 | json_path, video_folder=None, 19 | sample_size=256, sample_stride=4, sample_n_frames=16, 20 | is_image=False, 21 | **kwargs, 22 | ): 23 | logger.info(f"loading annotations from {json_path} ...") 24 | with open(json_path, 'rb') as json_file: 25 | json_list = list(json_file) 26 | self.dataset = [json.loads(json_str) for json_str in json_list] 27 | self.length = len(self.dataset) 28 | logger.info(f"data scale: {self.length}") 29 | 30 | self.video_folder = video_folder 31 | self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride) 32 | self.sample_n_frames = sample_n_frames 33 | self.is_image = is_image 34 | 35 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 36 | self.pixel_transforms = transforms.Compose([ 37 | transforms.RandomHorizontalFlip(), 38 | transforms.Resize(sample_size[0], antialias=None), 39 | transforms.CenterCrop(sample_size), 40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 41 | ]) 42 | 43 | def get_batch(self, idx): 44 | video_dict = self.dataset[idx] 45 | video_relative_path, name = video_dict['file'], video_dict['text'] 46 | 47 | if self.video_folder is not None: 48 | if video_relative_path[0] 
== '/': 49 | video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path)) 50 | else: 51 | video_dir = os.path.join(self.video_folder, video_relative_path) 52 | else: 53 | video_dir = video_relative_path 54 | video_reader = VideoReader(video_dir) 55 | video_length = len(video_reader) 56 | 57 | if not self.is_image: 58 | if isinstance(self.sample_stride, int): 59 | stride = self.sample_stride 60 | elif isinstance(self.sample_stride, tuple): 61 | stride = random.randint(self.sample_stride[0], self.sample_stride[1]) 62 | clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1) 63 | start_idx = random.randint(0, video_length - clip_length) 64 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 65 | else: 66 | frame_difference = random.randint(2, self.sample_n_frames) 67 | clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1) 68 | start_idx = random.randint(0, video_length - clip_length) 69 | batch_index = [start_idx, start_idx + clip_length - 1] 70 | 71 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 72 | pixel_values = pixel_values / 255. 73 | del video_reader 74 | 75 | return pixel_values, name 76 | 77 | def __len__(self): 78 | return self.length 79 | 80 | def __getitem__(self, idx): 81 | while True: 82 | try: 83 | pixel_values, name = self.get_batch(idx) 84 | break 85 | 86 | except Exception as e: 87 | idx = random.randint(0, self.length-1) 88 | 89 | pixel_values = self.pixel_transforms(pixel_values) 90 | sample = dict(pixel_values=pixel_values, text=name) 91 | return sample 92 | 93 | 94 | class Pexels(Dataset): 95 | def __init__( 96 | self, 97 | json_path, caption_json_path, video_folder=None, 98 | sample_size=256, sample_duration=1, sample_fps=8, 99 | is_image=False, 100 | **kwargs, 101 | ): 102 | logger.info(f"loading captions from {caption_json_path} ...") 103 | with open(caption_json_path, 'rb') as caption_json_file: 104 | caption_json_list = list(caption_json_file) 105 | self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list} 106 | 107 | logger.info(f"loading annotations from {json_path} ...") 108 | with open(json_path, 'rb') as json_file: 109 | json_list = list(json_file) 110 | dataset = [json.loads(json_str) for json_str in json_list] 111 | 112 | self.dataset = [] 113 | for data in dataset: 114 | data['text'] = self.caption_dict[data['id']] 115 | if data['height'] / data['width'] < 0.625: 116 | self.dataset.append(data) 117 | self.length = len(self.dataset) 118 | logger.info(f"data scale: {self.length}") 119 | 120 | self.video_folder = video_folder 121 | self.sample_duration = sample_duration 122 | self.sample_fps = sample_fps 123 | self.sample_n_frames = sample_duration * sample_fps 124 | self.is_image = is_image 125 | 126 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 127 | self.pixel_transforms = transforms.Compose([ 128 | transforms.RandomHorizontalFlip(), 129 | transforms.Resize(sample_size[0], antialias=None), 130 | transforms.CenterCrop(sample_size), 131 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 132 | ]) 133 | 134 | def get_batch(self, idx): 135 | video_dict = self.dataset[idx] 136 | video_relative_path, name = video_dict['file'], video_dict['text'] 137 | fps = video_dict['fps'] 138 | 139 | if self.video_folder is not None: 140 | if 
video_relative_path[0] == '/': 141 | video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path)) 142 | else: 143 | video_dir = os.path.join(self.video_folder, video_relative_path) 144 | else: 145 | video_dir = video_relative_path 146 | video_reader = VideoReader(video_dir) 147 | video_length = len(video_reader) 148 | 149 | if not self.is_image: 150 | clip_length = min(video_length, math.ceil(fps * self.sample_duration)) 151 | start_idx = random.randint(0, video_length - clip_length) 152 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 153 | else: 154 | frame_difference = random.randint(2, self.sample_n_frames) 155 | sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1) 156 | clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1) 157 | start_idx = random.randint(0, video_length - clip_length) 158 | batch_index = [start_idx, start_idx + clip_length - 1] 159 | 160 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 161 | pixel_values = pixel_values / 255. 162 | del video_reader 163 | 164 | return pixel_values, name 165 | 166 | def __len__(self): 167 | return self.length 168 | 169 | def __getitem__(self, idx): 170 | while True: 171 | try: 172 | pixel_values, name = self.get_batch(idx) 173 | break 174 | 175 | except Exception as e: 176 | idx = random.randint(0, self.length-1) 177 | 178 | pixel_values = self.pixel_transforms(pixel_values) 179 | sample = dict(pixel_values=pixel_values, text=name) 180 | return sample 181 | 182 | 183 | class JointDataset(Dataset): 184 | def __init__( 185 | self, 186 | webvid_config, pexels_config, 187 | sample_size=256, 188 | sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None, 189 | is_image=False, 190 | **kwargs, 191 | ): 192 | assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None" 193 | if sample_duration is not None and sample_fps is not None: 194 | assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None" 195 | if sample_stride is not None: 196 | assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None" 197 | 198 | self.dataset = [] 199 | 200 | if pexels_config.enable: 201 | logger.info(f"loading pexels dataset") 202 | logger.info(f"loading captions from {pexels_config.caption_json_path} ...") 203 | with open(pexels_config.caption_json_path, 'rb') as caption_json_file: 204 | caption_json_list = list(caption_json_file) 205 | self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list} 206 | 207 | logger.info(f"loading annotations from {pexels_config.json_path} ...") 208 | with open(pexels_config.json_path, 'rb') as json_file: 209 | json_list = list(json_file) 210 | dataset = [json.loads(json_str) for json_str in json_list] 211 | 212 | for data in dataset: 213 | data['text'] = self.caption_dict[data['id']] 214 | data['dataset'] = 'pexels' 215 | if data['height'] / data['width'] < 0.625: 216 | self.dataset.append(data) 217 | 218 | if webvid_config.enable: 219 | logger.info(f"loading webvid dataset") 220 | logger.info(f"loading annotations from {webvid_config.json_path} ...") 221 | with open(webvid_config.json_path, 'rb') as json_file: 222 
| json_list = list(json_file) 223 | dataset = [json.loads(json_str) for json_str in json_list] 224 | for data in dataset: 225 | data['dataset'] = 'webvid' 226 | self.dataset.extend(dataset) 227 | 228 | self.length = len(self.dataset) 229 | logger.info(f"data scale: {self.length}") 230 | 231 | self.pexels_folder = pexels_config.video_folder 232 | self.webvid_folder = webvid_config.video_folder 233 | self.sample_duration = sample_duration 234 | self.sample_fps = sample_fps 235 | self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames 236 | self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride) 237 | self.is_image = is_image 238 | 239 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 240 | self.pixel_transforms = transforms.Compose([ 241 | transforms.RandomHorizontalFlip(), 242 | transforms.Resize(sample_size[0], antialias=None), 243 | transforms.CenterCrop(sample_size), 244 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 245 | ]) 246 | 247 | def get_batch(self, idx): 248 | video_dict = self.dataset[idx] 249 | video_relative_path, name = video_dict['file'], video_dict['text'] 250 | 251 | if video_dict['dataset'] == 'pexels': 252 | video_folder = self.pexels_folder 253 | elif video_dict['dataset'] == 'webvid': 254 | video_folder = self.webvid_folder 255 | else: 256 | raise NotImplementedError 257 | 258 | if video_folder is not None: 259 | if video_relative_path[0] == '/': 260 | video_dir = os.path.join(video_folder, os.path.basename(video_relative_path)) 261 | else: 262 | video_dir = os.path.join(video_folder, video_relative_path) 263 | else: 264 | video_dir = video_relative_path 265 | video_reader = VideoReader(video_dir) 266 | video_length = len(video_reader) 267 | 268 | stride = None 269 | if not self.is_image: 270 | if self.sample_duration is not None: 271 | fps = video_dict['fps'] 272 | clip_length = min(video_length, math.ceil(fps * self.sample_duration)) 273 | elif self.sample_stride is not None: 274 | if isinstance(self.sample_stride, int): 275 | stride = self.sample_stride 276 | elif isinstance(self.sample_stride, tuple): 277 | stride = random.randint(self.sample_stride[0], self.sample_stride[1]) 278 | clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1) 279 | 280 | start_idx = random.randint(0, video_length - clip_length) 281 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 282 | 283 | else: 284 | frame_difference = random.randint(2, self.sample_n_frames) 285 | if self.sample_duration is not None: 286 | fps = video_dict['fps'] 287 | sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1) 288 | elif self.sample_stride is not None: 289 | sample_stride = self.sample_stride 290 | 291 | clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1) 292 | start_idx = random.randint(0, video_length - clip_length) 293 | batch_index = [start_idx, start_idx + clip_length - 1] 294 | 295 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 296 | pixel_values = pixel_values / 255. 
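# Frames are returned in [0, 1] here; the Normalize(mean=0.5, std=0.5) transform applied in
# __getitem__ then rescales them to [-1, 1], the input range the latent-diffusion VAE expects.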
297 | del video_reader 298 | 299 | return pixel_values, name, stride 300 | 301 | def __len__(self): 302 | return self.length 303 | 304 | def __getitem__(self, idx): 305 | while True: 306 | try: 307 | pixel_values, name, stride = self.get_batch(idx) 308 | break 309 | 310 | except Exception as e: 311 | idx = random.randint(0, self.length-1) 312 | 313 | pixel_values = self.pixel_transforms(pixel_values) 314 | sample = dict(pixel_values=pixel_values, text=name, stride=stride) 315 | return sample 316 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | import random 6 | import requests 7 | from PIL import Image 8 | import numpy as np 9 | 10 | import gradio as gr 11 | from datetime import datetime 12 | 13 | import torchvision.transforms as T 14 | 15 | from diffusers import DDIMScheduler 16 | from diffusers.utils.import_utils import is_xformers_available 17 | from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline 18 | from consisti2v.utils.util import save_videos_grid 19 | from omegaconf import OmegaConf 20 | 21 | 22 | sample_idx = 0 23 | scheduler_dict = { 24 | "DDIM": DDIMScheduler, 25 | } 26 | 27 | css = """ 28 | .toolbutton { 29 | margin-buttom: 0em 0em 0em 0em; 30 | max-width: 2.5em; 31 | min-width: 2.5em !important; 32 | height: 2.5em; 33 | } 34 | """ 35 | 36 | class AnimateController: 37 | def __init__(self): 38 | 39 | # config dirs 40 | self.basedir = os.getcwd() 41 | self.savedir = os.path.join(self.basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) 42 | self.savedir_sample = os.path.join(self.savedir, "sample") 43 | os.makedirs(self.savedir, exist_ok=True) 44 | 45 | self.image_resolution = (256, 256) 46 | # config models 47 | self.pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16,) 48 | self.pipeline.to("cuda") 49 | 50 | def update_textbox_and_save_image(self, input_image, height_slider, width_slider, center_crop): 51 | pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB") 52 | img_path = os.path.join(self.savedir, "input_image.png") 53 | pil_image.save(img_path) 54 | self.image_resolution = pil_image.size 55 | original_width, original_height = pil_image.size 56 | if center_crop: 57 | crop_aspect_ratio = width_slider / height_slider 58 | aspect_ratio = original_width / original_height 59 | if aspect_ratio > crop_aspect_ratio: 60 | new_width = int(crop_aspect_ratio * original_height) 61 | left = (original_width - new_width) / 2 62 | top = 0 63 | right = left + new_width 64 | bottom = original_height 65 | pil_image = pil_image.crop((left, top, right, bottom)) 66 | elif aspect_ratio < crop_aspect_ratio: 67 | new_height = int(original_width / crop_aspect_ratio) 68 | top = (original_height - new_height) / 2 69 | left = 0 70 | right = original_width 71 | bottom = top + new_height 72 | pil_image = pil_image.crop((left, top, right, bottom)) 73 | 74 | pil_image = pil_image.resize((width_slider, height_slider)) 75 | return gr.Textbox.update(value=img_path), gr.Image.update(value=np.array(pil_image)) 76 | 77 | def animate( 78 | self, 79 | prompt_textbox, 80 | negative_prompt_textbox, 81 | input_image_path, 82 | sampler_dropdown, 83 | sample_step_slider, 84 | width_slider, 85 | height_slider, 86 | txt_cfg_scale_slider, 87 | img_cfg_scale_slider, 88 | center_crop, 89 | frame_stride, 90 | 
use_frameinit, 91 | frame_init_noise_level, 92 | seed_textbox 93 | ): 94 | if self.pipeline is None: 95 | raise gr.Error(f"Please select a pretrained pipeline path.") 96 | if input_image_path == "": 97 | raise gr.Error(f"Please upload an input image.") 98 | if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0): 99 | raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.") 100 | if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0): 101 | raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.") 102 | 103 | if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: self.pipeline.unet.enable_xformers_memory_efficient_attention() 104 | 105 | if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) 106 | else: torch.seed() 107 | seed = torch.initial_seed() 108 | 109 | if input_image_path.startswith("http://") or input_image_path.startswith("https://"): 110 | first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB') 111 | else: 112 | first_frame = Image.open(input_image_path).convert('RGB') 113 | 114 | original_width, original_height = first_frame.size 115 | 116 | if not center_crop: 117 | img_transform = T.Compose([ 118 | T.ToTensor(), 119 | T.Resize((height_slider, width_slider), antialias=None), 120 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 121 | ]) 122 | else: 123 | aspect_ratio = original_width / original_height 124 | crop_aspect_ratio = width_slider / height_slider 125 | if aspect_ratio > crop_aspect_ratio: 126 | center_crop_width = int(crop_aspect_ratio * original_height) 127 | center_crop_height = original_height 128 | elif aspect_ratio < crop_aspect_ratio: 129 | center_crop_width = original_width 130 | center_crop_height = int(original_width / crop_aspect_ratio) 131 | else: 132 | center_crop_width = original_width 133 | center_crop_height = original_height 134 | img_transform = T.Compose([ 135 | T.ToTensor(), 136 | T.CenterCrop((center_crop_height, center_crop_width)), 137 | T.Resize((height_slider, width_slider), antialias=None), 138 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 139 | ]) 140 | 141 | first_frame = img_transform(first_frame).unsqueeze(0) 142 | first_frame = first_frame.to("cuda") 143 | 144 | if use_frameinit: 145 | self.pipeline.init_filter( 146 | width = width_slider, 147 | height = height_slider, 148 | video_length = 16, 149 | filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,}) 150 | ) 151 | 152 | 153 | sample = self.pipeline( 154 | prompt_textbox, 155 | negative_prompt = negative_prompt_textbox, 156 | first_frames = first_frame, 157 | num_inference_steps = sample_step_slider, 158 | guidance_scale_txt = txt_cfg_scale_slider, 159 | guidance_scale_img = img_cfg_scale_slider, 160 | width = width_slider, 161 | height = height_slider, 162 | video_length = 16, 163 | noise_sampling_method = "pyoco_mixed", 164 | noise_alpha = 1.0, 165 | frame_stride = frame_stride, 166 | use_frameinit = use_frameinit, 167 | frameinit_noise_level = frame_init_noise_level, 168 | camera_motion = None, 169 | ).videos 170 | 171 | global sample_idx 172 | sample_idx += 1 173 | save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4") 174 | save_videos_grid(sample, save_sample_path, format="mp4") 175 | 176 | sample_config = { 177 | "prompt": prompt_textbox, 178 | 
"n_prompt": negative_prompt_textbox, 179 | "first_frame_path": input_image_path, 180 | "sampler": sampler_dropdown, 181 | "num_inference_steps": sample_step_slider, 182 | "guidance_scale_text": txt_cfg_scale_slider, 183 | "guidance_scale_image": img_cfg_scale_slider, 184 | "width": width_slider, 185 | "height": height_slider, 186 | "video_length": 8, 187 | "seed": seed 188 | } 189 | json_str = json.dumps(sample_config, indent=4) 190 | with open(os.path.join(self.savedir, "logs.json"), "a") as f: 191 | f.write(json_str) 192 | f.write("\n\n") 193 | 194 | return gr.Video.update(value=save_sample_path) 195 | 196 | 197 | controller = AnimateController() 198 | 199 | 200 | def ui(): 201 | with gr.Blocks(css=css) as demo: 202 | gr.Markdown( 203 | """ 204 | # ConsistI2V Text+Image to Video Generation 205 | Input image will be used as the first frame of the video. Text prompts will be used to control the output video content. 206 | """ 207 | ) 208 | 209 | with gr.Column(variant="panel"): 210 | gr.Markdown( 211 | """ 212 | - Input image can be specified using the "Input Image Path/URL" text box (this can be either a local image path or an image URL) or uploaded by clicking or dragging the image to the "Input Image" box. The uploaded image will be temporarily stored in the "samples/Gradio" folder under the project root folder. 213 | - Input image can be resized and/or center cropped to a given resolution by adjusting the "Width" and "Height" sliders. It is recommended to use the same resolution as the training resolution (256x256). 214 | - After setting the input image path or changed the width/height of the input image, press the "Preview" button to visualize the resized input image. 215 | """ 216 | ) 217 | 218 | with gr.Row(): 219 | prompt_textbox = gr.Textbox(label="Prompt", lines=2) 220 | negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2) 221 | 222 | with gr.Row().style(equal_height=False): 223 | with gr.Column(): 224 | with gr.Row(): 225 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) 226 | sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1) 227 | 228 | with gr.Row(): 229 | center_crop = gr.Checkbox(label="Center Crop the Image", value=True) 230 | width_slider = gr.Slider(label="Width", value=256, minimum=0, maximum=512, step=64) 231 | height_slider = gr.Slider(label="Height", value=256, minimum=0, maximum=512, step=64) 232 | with gr.Row(): 233 | txt_cfg_scale_slider = gr.Slider(label="Text CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.5) 234 | img_cfg_scale_slider = gr.Slider(label="Image CFG Scale", value=1.0, minimum=1.0, maximum=20.0, step=0.5) 235 | frame_stride = gr.Slider(label="Frame Stride", value=3, minimum=1, maximum=5, step=1) 236 | 237 | with gr.Row(): 238 | use_frameinit = gr.Checkbox(label="Enable FrameInit", value=True) 239 | frameinit_noise_level = gr.Slider(label="FrameInit Noise Level", value=850, minimum=1, maximum=999, step=1) 240 | 241 | 242 | seed_textbox = gr.Textbox(label="Seed", value=-1) 243 | seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") 244 | seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) 245 | 246 | 247 | 248 | generate_button = gr.Button(value="Generate", variant='primary') 249 | 250 | with gr.Column(): 251 | with gr.Row(): 252 | input_image_path = gr.Textbox(label="Input Image Path/URL", lines=1, scale=10, 
info="Press Enter or the Preview button to confirm the input image.") 253 | preview_button = gr.Button(value="Preview") 254 | 255 | with gr.Row(): 256 | input_image = gr.Image(label="Input Image", interactive=True) 257 | input_image.upload(fn=controller.update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image]) 258 | result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True) 259 | 260 | def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop): 261 | if input_image_path.startswith("http://") or input_image_path.startswith("https://"): 262 | pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB') 263 | else: 264 | pil_image = Image.open(input_image_path).convert('RGB') 265 | controller.image_resolution = pil_image.size 266 | original_width, original_height = pil_image.size 267 | 268 | if center_crop: 269 | crop_aspect_ratio = width_slider / height_slider 270 | aspect_ratio = original_width / original_height 271 | if aspect_ratio > crop_aspect_ratio: 272 | new_width = int(crop_aspect_ratio * original_height) 273 | left = (original_width - new_width) / 2 274 | top = 0 275 | right = left + new_width 276 | bottom = original_height 277 | pil_image = pil_image.crop((left, top, right, bottom)) 278 | elif aspect_ratio < crop_aspect_ratio: 279 | new_height = int(original_width / crop_aspect_ratio) 280 | top = (original_height - new_height) / 2 281 | left = 0 282 | right = original_width 283 | bottom = top + new_height 284 | pil_image = pil_image.crop((left, top, right, bottom)) 285 | 286 | pil_image = pil_image.resize((width_slider, height_slider)) 287 | return gr.Image.update(value=np.array(pil_image)) 288 | 289 | preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image]) 290 | input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image]) 291 | 292 | generate_button.click( 293 | fn=controller.animate, 294 | inputs=[ 295 | prompt_textbox, 296 | negative_prompt_textbox, 297 | input_image_path, 298 | sampler_dropdown, 299 | sample_step_slider, 300 | width_slider, 301 | height_slider, 302 | txt_cfg_scale_slider, 303 | img_cfg_scale_slider, 304 | center_crop, 305 | frame_stride, 306 | use_frameinit, 307 | frameinit_noise_level, 308 | seed_textbox, 309 | ], 310 | outputs=[result_video] 311 | ) 312 | 313 | return demo 314 | 315 | 316 | if __name__ == "__main__": 317 | demo = ui() 318 | demo.launch(share=True) 319 | -------------------------------------------------------------------------------- /consisti2v/models/videoldm_transformer_blocks.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/diffusers/blob/v0.21.0/src/diffusers/models/transformer_2d.py 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from einops import rearrange, repeat 10 | 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models.embeddings import ImagePositionalEmbeddings 13 | from diffusers.utils import BaseOutput, deprecate 14 | from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, FeedForward, GatedSelfAttentionDense 15 | from 
diffusers.models.embeddings import PatchEmbed 16 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 17 | from diffusers.models.modeling_utils import ModelMixin 18 | from diffusers.models.transformer_2d import Transformer2DModelOutput 19 | from diffusers.utils.torch_utils import maybe_allow_in_graph 20 | from diffusers.models.attention_processor import Attention 21 | from diffusers.models.lora import LoRACompatibleLinear 22 | 23 | from .videoldm_attention import ConditionalAttention, TemporalConditionalAttention 24 | 25 | 26 | class Transformer2DConditionModel(ModelMixin, ConfigMixin): 27 | @register_to_config 28 | def __init__( 29 | self, 30 | num_attention_heads: int = 16, 31 | attention_head_dim: int = 88, 32 | in_channels: Optional[int] = None, 33 | out_channels: Optional[int] = None, 34 | num_layers: int = 1, 35 | dropout: float = 0.0, 36 | norm_num_groups: int = 32, 37 | cross_attention_dim: Optional[int] = None, 38 | attention_bias: bool = False, 39 | sample_size: Optional[int] = None, 40 | num_vector_embeds: Optional[int] = None, 41 | patch_size: Optional[int] = None, 42 | activation_fn: str = "geglu", 43 | num_embeds_ada_norm: Optional[int] = None, 44 | use_linear_projection: bool = False, 45 | only_cross_attention: bool = False, 46 | double_self_attention: bool = False, 47 | upcast_attention: bool = False, 48 | norm_type: str = "layer_norm", 49 | norm_elementwise_affine: bool = True, 50 | attention_type: str = "default", 51 | # additional 52 | n_frames: int = 8, 53 | is_temporal: bool = False, 54 | augment_temporal_attention: bool = False, 55 | rotary_emb=False, 56 | ): 57 | super().__init__() 58 | self.use_linear_projection = use_linear_projection 59 | self.num_attention_heads = num_attention_heads 60 | self.attention_head_dim = attention_head_dim 61 | inner_dim = num_attention_heads * attention_head_dim 62 | 63 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 64 | # Define whether input is continuous or discrete depending on configuration 65 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 66 | self.is_input_vectorized = num_vector_embeds is not None 67 | self.is_input_patches = in_channels is not None and patch_size is not None 68 | 69 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 70 | deprecation_message = ( 71 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 72 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 73 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 74 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 75 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 76 | ) 77 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 78 | norm_type = "ada_norm" 79 | 80 | if self.is_input_continuous and self.is_input_vectorized: 81 | raise ValueError( 82 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 83 | " sure that either `in_channels` or `num_vector_embeds` is None." 
84 | ) 85 | elif self.is_input_vectorized and self.is_input_patches: 86 | raise ValueError( 87 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 88 | " sure that either `num_vector_embeds` or `num_patches` is None." 89 | ) 90 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 91 | raise ValueError( 92 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 93 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 94 | ) 95 | 96 | # 2. Define input layers 97 | if self.is_input_continuous: 98 | self.in_channels = in_channels 99 | 100 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 101 | if use_linear_projection: 102 | self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) 103 | else: 104 | self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 105 | elif self.is_input_vectorized: 106 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 107 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 108 | 109 | self.height = sample_size 110 | self.width = sample_size 111 | self.num_vector_embeds = num_vector_embeds 112 | self.num_latent_pixels = self.height * self.width 113 | 114 | self.latent_image_embedding = ImagePositionalEmbeddings( 115 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 116 | ) 117 | elif self.is_input_patches: 118 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 119 | 120 | self.height = sample_size 121 | self.width = sample_size 122 | 123 | self.patch_size = patch_size 124 | self.pos_embed = PatchEmbed( 125 | height=sample_size, 126 | width=sample_size, 127 | patch_size=patch_size, 128 | in_channels=in_channels, 129 | embed_dim=inner_dim, 130 | ) 131 | 132 | # 3. Define transformers blocks 133 | self.transformer_blocks = nn.ModuleList( 134 | [ 135 | BasicConditionalTransformerBlock( 136 | inner_dim, 137 | num_attention_heads, 138 | attention_head_dim, 139 | dropout=dropout, 140 | cross_attention_dim=cross_attention_dim, 141 | activation_fn=activation_fn, 142 | num_embeds_ada_norm=num_embeds_ada_norm, 143 | attention_bias=attention_bias, 144 | only_cross_attention=only_cross_attention, 145 | double_self_attention=double_self_attention, 146 | upcast_attention=upcast_attention, 147 | norm_type=norm_type, 148 | norm_elementwise_affine=norm_elementwise_affine, 149 | attention_type=attention_type, 150 | # additional 151 | n_frames=n_frames, 152 | is_temporal=is_temporal, 153 | augment_temporal_attention=augment_temporal_attention, 154 | rotary_emb=rotary_emb, 155 | ) 156 | for d in range(num_layers) 157 | ] 158 | ) 159 | 160 | # 4. 
Define output layers 161 | self.out_channels = in_channels if out_channels is None else out_channels 162 | if self.is_input_continuous: 163 | # TODO: should use out_channels for continuous projections 164 | if use_linear_projection: 165 | self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) 166 | else: 167 | self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 168 | elif self.is_input_vectorized: 169 | self.norm_out = nn.LayerNorm(inner_dim) 170 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 171 | elif self.is_input_patches: 172 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 173 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 174 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 175 | 176 | self.alpha = None 177 | if is_temporal: 178 | self.alpha = nn.Parameter(torch.ones(1)) 179 | 180 | self.gradient_checkpointing = False 181 | 182 | def forward( 183 | self, 184 | hidden_states: torch.Tensor, 185 | encoder_hidden_states: Optional[torch.Tensor] = None, 186 | timestep: Optional[torch.LongTensor] = None, 187 | class_labels: Optional[torch.LongTensor] = None, 188 | cross_attention_kwargs: Dict[str, Any] = None, 189 | attention_mask: Optional[torch.Tensor] = None, 190 | encoder_attention_mask: Optional[torch.Tensor] = None, 191 | return_dict: bool = True, 192 | condition_on_first_frame: bool = False, 193 | ): 194 | input_states = hidden_states 195 | input_height, input_width = hidden_states.shape[-2:] 196 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 197 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 198 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 199 | # expects mask of shape: 200 | # [batch, key_tokens] 201 | # adds singleton query_tokens dimension: 202 | # [batch, 1, key_tokens] 203 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 204 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 205 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 206 | if attention_mask is not None and attention_mask.ndim == 2: 207 | # assume that mask is expressed as: 208 | # (1 = keep, 0 = discard) 209 | # convert mask into a bias that can be added to attention scores: 210 | # (keep = +0, discard = -10000.0) 211 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 212 | attention_mask = attention_mask.unsqueeze(1) 213 | 214 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 215 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 216 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 217 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 218 | 219 | # Retrieve lora scale. 220 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 221 | 222 | # 1. 
Input 223 | if self.is_input_continuous: 224 | batch, _, height, width = hidden_states.shape 225 | residual = hidden_states 226 | 227 | hidden_states = self.norm(hidden_states) 228 | if not self.use_linear_projection: 229 | hidden_states = self.proj_in(hidden_states, lora_scale) 230 | inner_dim = hidden_states.shape[1] 231 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 232 | else: 233 | inner_dim = hidden_states.shape[1] 234 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 235 | hidden_states = self.proj_in(hidden_states, scale=lora_scale) 236 | 237 | elif self.is_input_vectorized: 238 | hidden_states = self.latent_image_embedding(hidden_states) 239 | elif self.is_input_patches: 240 | hidden_states = self.pos_embed(hidden_states) 241 | 242 | # 2. Blocks 243 | for block in self.transformer_blocks: 244 | if self.training and self.gradient_checkpointing: 245 | hidden_states = torch.utils.checkpoint.checkpoint( 246 | block, 247 | hidden_states, 248 | attention_mask, 249 | encoder_hidden_states, 250 | encoder_attention_mask, 251 | timestep, 252 | cross_attention_kwargs, 253 | class_labels, 254 | use_reentrant=False, 255 | ) 256 | else: 257 | hidden_states = block( 258 | hidden_states, 259 | attention_mask=attention_mask, 260 | encoder_hidden_states=encoder_hidden_states, 261 | encoder_attention_mask=encoder_attention_mask, 262 | timestep=timestep, 263 | cross_attention_kwargs=cross_attention_kwargs, 264 | class_labels=class_labels, 265 | # additional 266 | condition_on_first_frame=condition_on_first_frame, 267 | input_height=input_height, 268 | input_width=input_width, 269 | ) 270 | 271 | # 3. Output 272 | if self.is_input_continuous: 273 | if not self.use_linear_projection: 274 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 275 | hidden_states = self.proj_out(hidden_states, scale=lora_scale) 276 | else: 277 | hidden_states = self.proj_out(hidden_states, scale=lora_scale) 278 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 279 | 280 | output = hidden_states + residual 281 | elif self.is_input_vectorized: 282 | hidden_states = self.norm_out(hidden_states) 283 | logits = self.out(hidden_states) 284 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 285 | logits = logits.permute(0, 2, 1) 286 | 287 | # log(p(x_0)) 288 | output = F.log_softmax(logits.double(), dim=1).float() 289 | elif self.is_input_patches: 290 | # TODO: cleanup! 
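# Patched-input output path: apply the AdaLN-style shift/scale derived from the timestep/class
# embedding, project each token to patch_size * patch_size * out_channels values, then
# "unpatchify" the token grid back into an image-shaped tensor.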
291 | conditioning = self.transformer_blocks[0].norm1.emb( 292 | timestep, class_labels, hidden_dtype=hidden_states.dtype 293 | ) 294 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 295 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 296 | hidden_states = self.proj_out_2(hidden_states) 297 | 298 | # unpatchify 299 | height = width = int(hidden_states.shape[1] ** 0.5) 300 | hidden_states = hidden_states.reshape( 301 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 302 | ) 303 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 304 | output = hidden_states.reshape( 305 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 306 | ) 307 | 308 | if self.alpha is not None: 309 | with torch.no_grad(): 310 | self.alpha.clamp_(0, 1) 311 | 312 | output = self.alpha * input_states + (1 - self.alpha) * output 313 | 314 | if not return_dict: 315 | return (output,) 316 | 317 | return Transformer2DModelOutput(sample=output) 318 | 319 | 320 | @maybe_allow_in_graph 321 | class BasicConditionalTransformerBlock(nn.Module): 322 | """ transformer block with first frame conditioning """ 323 | def __init__( 324 | self, 325 | dim: int, 326 | num_attention_heads: int, 327 | attention_head_dim: int, 328 | dropout=0.0, 329 | cross_attention_dim: Optional[int] = None, 330 | activation_fn: str = "geglu", 331 | num_embeds_ada_norm: Optional[int] = None, 332 | attention_bias: bool = False, 333 | only_cross_attention: bool = False, 334 | double_self_attention: bool = False, 335 | upcast_attention: bool = False, 336 | norm_elementwise_affine: bool = True, 337 | norm_type: str = "layer_norm", 338 | final_dropout: bool = False, 339 | attention_type: str = "default", 340 | # additional 341 | n_frames: int = 8, 342 | is_temporal: bool = False, 343 | augment_temporal_attention: bool = False, 344 | rotary_emb=False, 345 | ): 346 | super().__init__() 347 | self.n_frames = n_frames 348 | self.only_cross_attention = only_cross_attention 349 | self.augment_temporal_attention = augment_temporal_attention 350 | self.is_temporal = is_temporal 351 | 352 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 353 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 354 | 355 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 356 | raise ValueError( 357 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 358 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 359 | ) 360 | 361 | # Define 3 blocks. Each block has its own normalization layer. 362 | # 1. 
Self-Attn 363 | if self.use_ada_layer_norm: 364 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 365 | elif self.use_ada_layer_norm_zero: 366 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 367 | else: 368 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 369 | 370 | if not is_temporal: 371 | self.attn1 = ConditionalAttention( 372 | query_dim=dim, 373 | heads=num_attention_heads, 374 | dim_head=attention_head_dim, 375 | dropout=dropout, 376 | bias=attention_bias, 377 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 378 | upcast_attention=upcast_attention, 379 | ) 380 | else: 381 | self.attn1 = TemporalConditionalAttention( 382 | query_dim=dim, 383 | heads=num_attention_heads, 384 | dim_head=attention_head_dim, 385 | dropout=dropout, 386 | bias=attention_bias, 387 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 388 | upcast_attention=upcast_attention, 389 | # additional 390 | n_frames=n_frames, 391 | rotary_emb=rotary_emb, 392 | ) 393 | 394 | # 2. Cross-Attn 395 | if cross_attention_dim is not None or double_self_attention: 396 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 397 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 398 | # the second cross attention block. 399 | self.norm2 = ( 400 | AdaLayerNorm(dim, num_embeds_ada_norm) 401 | if self.use_ada_layer_norm 402 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 403 | ) 404 | if not is_temporal: 405 | self.attn2 = ConditionalAttention( 406 | query_dim=dim, 407 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 408 | heads=num_attention_heads, 409 | dim_head=attention_head_dim, 410 | dropout=dropout, 411 | bias=attention_bias, 412 | upcast_attention=upcast_attention, 413 | ) # is self-attn if encoder_hidden_states is none 414 | else: 415 | self.attn2 = TemporalConditionalAttention( 416 | query_dim=dim, 417 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 418 | heads=num_attention_heads, 419 | dim_head=attention_head_dim, 420 | dropout=dropout, 421 | bias=attention_bias, 422 | upcast_attention=upcast_attention, 423 | # additional 424 | n_frames=n_frames, 425 | rotary_emb=rotary_emb, 426 | ) 427 | else: 428 | self.norm2 = None 429 | self.attn2 = None 430 | 431 | # 3. Feed-forward 432 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 433 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 434 | 435 | # 4. 
Fuser 436 | if attention_type == "gated" or attention_type == "gated-text-image": 437 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 438 | 439 | # let chunk size default to None 440 | self._chunk_size = None 441 | self._chunk_dim = 0 442 | 443 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 444 | # Sets chunk feed-forward 445 | self._chunk_size = chunk_size 446 | self._chunk_dim = dim 447 | 448 | def forward( 449 | self, 450 | hidden_states: torch.FloatTensor, 451 | attention_mask: Optional[torch.FloatTensor] = None, 452 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 453 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 454 | timestep: Optional[torch.LongTensor] = None, 455 | cross_attention_kwargs: Dict[str, Any] = None, 456 | class_labels: Optional[torch.LongTensor] = None, 457 | condition_on_first_frame: bool = False, 458 | input_height: Optional[int] = None, 459 | input_width: Optional[int] = None, 460 | ): 461 | # Notice that normalization is always applied before the real computation in the following blocks. 462 | # 0. Self-Attention 463 | if self.use_ada_layer_norm: 464 | norm_hidden_states = self.norm1(hidden_states, timestep) 465 | elif self.use_ada_layer_norm_zero: 466 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 467 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 468 | ) 469 | else: 470 | norm_hidden_states = self.norm1(hidden_states) 471 | 472 | # 1. Retrieve lora scale. 473 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 474 | 475 | # 2. Prepare GLIGEN inputs 476 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 477 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 478 | 479 | if condition_on_first_frame: 480 | first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :] 481 | first_frame_hidden_states = repeat(first_frame_hidden_states, 'b d h -> b f d h', f=self.n_frames) 482 | first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b f d h -> (b f) d h') 483 | first_frame_concat_hidden_states = torch.cat((norm_hidden_states, first_frame_hidden_states), dim=1) 484 | attn_output = self.attn1( 485 | norm_hidden_states, 486 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else first_frame_concat_hidden_states, 487 | attention_mask=attention_mask, 488 | **cross_attention_kwargs, 489 | ) 490 | elif self.is_temporal and self.augment_temporal_attention: 491 | first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :] 492 | first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b (h w) c -> b h w c', h=input_height, w=input_width) 493 | first_frame_hidden_states = first_frame_hidden_states.permute(0, 3, 1, 2) 494 | padded_first_frame = torch.nn.functional.pad(first_frame_hidden_states, (1, 1, 1, 1), "replicate") 495 | first_frame_windows = padded_first_frame.unfold(2, 3, 1).unfold(3, 3, 1) 496 | mask = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.bool) 497 | adjacent_slices = first_frame_windows[:, :, :, :, mask] 498 | attn_output = self.attn1( 499 | norm_hidden_states, 500 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 501 | attention_mask=attention_mask, 502 | adjacent_slices=adjacent_slices, 503 | 
**cross_attention_kwargs, 504 | ) 505 | else: 506 | attn_output = self.attn1( 507 | norm_hidden_states, 508 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 509 | attention_mask=attention_mask, 510 | **cross_attention_kwargs, 511 | ) 512 | if self.use_ada_layer_norm_zero: 513 | attn_output = gate_msa.unsqueeze(1) * attn_output 514 | hidden_states = attn_output + hidden_states 515 | 516 | # 2.5 GLIGEN Control 517 | if gligen_kwargs is not None: 518 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 519 | # 2.5 ends 520 | 521 | # 3. Cross-Attention 522 | if self.attn2 is not None: 523 | norm_hidden_states = ( 524 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 525 | ) 526 | 527 | attn_output = self.attn2( 528 | norm_hidden_states, 529 | encoder_hidden_states=encoder_hidden_states, 530 | attention_mask=encoder_attention_mask, 531 | **cross_attention_kwargs, 532 | ) 533 | hidden_states = attn_output + hidden_states 534 | 535 | # 4. Feed-forward 536 | norm_hidden_states = self.norm3(hidden_states) 537 | 538 | if self.use_ada_layer_norm_zero: 539 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 540 | 541 | if self._chunk_size is not None: 542 | # "feed_forward_chunk_size" can be used to save memory 543 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 544 | raise ValueError( 545 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 546 | ) 547 | 548 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 549 | ff_output = torch.cat( 550 | [ 551 | self.ff(hid_slice, scale=lora_scale) 552 | for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) 553 | ], 554 | dim=self._chunk_dim, 555 | ) 556 | else: 557 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 558 | 559 | if self.use_ada_layer_norm_zero: 560 | ff_output = gate_mlp.unsqueeze(1) * ff_output 561 | 562 | hidden_states = ff_output + hidden_states 563 | 564 | return hidden_states -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import wandb 4 | import random 5 | import time 6 | import logging 7 | import inspect 8 | import argparse 9 | import datetime 10 | import numpy as np 11 | 12 | from pathlib import Path 13 | from tqdm.auto import tqdm 14 | from einops import rearrange, repeat 15 | from omegaconf import OmegaConf 16 | from typing import Dict, Optional, Tuple 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | import diffusers 22 | from diffusers import AutoencoderKL, DDIMScheduler 23 | from diffusers.optimization import get_scheduler 24 | from diffusers.utils import check_min_version 25 | from diffusers.utils.import_utils import is_xformers_available 26 | from diffusers.training_utils import EMAModel 27 | 28 | import transformers 29 | from transformers import CLIPTextModel, CLIPTokenizer 30 | 31 | from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs 32 | from accelerate.logging import get_logger 33 | from accelerate.utils import set_seed 34 | 35 | from consisti2v.data.dataset import WebVid10M, Pexels, JointDataset 36 | from 
consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel 37 | from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline 38 | from consisti2v.utils.util import save_videos_grid 39 | 40 | logger = get_logger(__name__, log_level="INFO") 41 | 42 | def main( 43 | name: str, 44 | use_wandb: bool, 45 | 46 | is_image: bool, 47 | 48 | output_dir: str, 49 | pretrained_model_path: str, 50 | 51 | train_data: Dict, 52 | validation_data: Dict, 53 | 54 | cfg_random_null_text_ratio: float = 0.1, 55 | cfg_random_null_img_ratio: float = 0.0, 56 | 57 | resume_from_checkpoint: Optional[str] = None, 58 | unet_additional_kwargs: Dict = {}, 59 | use_ema: bool = False, 60 | ema_decay: float = 0.9999, 61 | noise_scheduler_kwargs = None, 62 | 63 | max_train_epoch: int = -1, 64 | max_train_steps: int = 100, 65 | validation_steps: int = 100, 66 | 67 | learning_rate: float = 3e-5, 68 | scale_lr: bool = False, 69 | lr_warmup_steps: int = 0, 70 | lr_scheduler: str = "constant", 71 | 72 | trainable_modules: Tuple[str] = (None, ), 73 | num_workers: int = 32, 74 | train_batch_size: int = 1, 75 | adam_beta1: float = 0.9, 76 | adam_beta2: float = 0.999, 77 | adam_weight_decay: float = 1e-2, 78 | adam_epsilon: float = 1e-08, 79 | max_grad_norm: float = 1.0, 80 | gradient_accumulation_steps: int = 1, 81 | gradient_checkpointing: bool = False, 82 | checkpointing_epochs: int = 5, 83 | checkpointing_steps: int = -1, 84 | 85 | mixed_precision: Optional[str] = "fp16", 86 | enable_xformers_memory_efficient_attention: bool = True, 87 | 88 | seed: Optional[int] = 42, 89 | is_debug: bool = False, 90 | ): 91 | check_min_version("0.10.0.dev0") 92 | *_, config = inspect.getargvalues(inspect.currentframe()) 93 | config = {k: v for k, v in config.items() if k != 'config' and k != '_'} 94 | 95 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True if not is_image else False) 96 | init_kwargs = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=3600)) 97 | 98 | accelerator = Accelerator( 99 | gradient_accumulation_steps=gradient_accumulation_steps, 100 | mixed_precision=mixed_precision, 101 | kwargs_handlers=[ddp_kwargs, init_kwargs], 102 | ) 103 | 104 | if seed is not None: 105 | set_seed(seed) 106 | 107 | # Logging folder 108 | folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S") 109 | output_dir = os.path.join(output_dir, folder_name) 110 | if is_debug and os.path.exists(output_dir): 111 | os.system(f"rm -rf {output_dir}") 112 | 113 | # Make one log on every process with the configuration for debugging. 
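# Note that `logger.info(accelerator.state, main_process_only=False)` below deliberately logs from
# every rank, so the per-process device and mixed-precision setup is visible in multi-GPU runs.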
114 | logging.basicConfig( 115 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 116 | datefmt="%m/%d/%Y %H:%M:%S", 117 | level=logging.INFO, 118 | ) 119 | logger.info(accelerator.state, main_process_only=False) 120 | 121 | if accelerator.is_local_main_process: 122 | transformers.utils.logging.set_verbosity_warning() 123 | diffusers.utils.logging.set_verbosity_info() 124 | else: 125 | transformers.utils.logging.set_verbosity_error() 126 | diffusers.utils.logging.set_verbosity_error() 127 | 128 | if accelerator.is_main_process and (not is_debug) and use_wandb: 129 | project_name = "text_image_to_video" if not is_image else "image_finetune" 130 | wandb.init(project=project_name, name=folder_name, config=config) 131 | accelerator.wait_for_everyone() 132 | 133 | # Handle the output folder creation 134 | if accelerator.is_main_process: 135 | os.makedirs(output_dir, exist_ok=True) 136 | os.makedirs(f"{output_dir}/samples", exist_ok=True) 137 | os.makedirs(f"{output_dir}/sanity_check", exist_ok=True) 138 | os.makedirs(f"{output_dir}/checkpoints", exist_ok=True) 139 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 140 | 141 | # TODO: change all datasets to fps+duration in the future 142 | if train_data.dataset == "pexels": 143 | train_data.sample_n_frames = train_data.sample_duration * train_data.sample_fps 144 | elif train_data.dataset == "joint": 145 | if train_data.sample_duration is not None: 146 | train_data.sample_n_frames = train_data.sample_duration * train_data.sample_fps 147 | # Load scheduler, tokenizer and models. 148 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 149 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 150 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 151 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 152 | unet = VideoLDMUNet3DConditionModel.from_pretrained( 153 | pretrained_model_path, 154 | subfolder="unet", 155 | variant=unet_additional_kwargs['variant'], 156 | use_temporal=True if not is_image else False, 157 | temp_pos_embedding=unet_additional_kwargs['temp_pos_embedding'], 158 | augment_temporal_attention=unet_additional_kwargs['augment_temporal_attention'], 159 | n_frames=train_data.sample_n_frames if not is_image else 2, 160 | n_temp_heads=unet_additional_kwargs['n_temp_heads'], 161 | first_frame_condition_mode=unet_additional_kwargs['first_frame_condition_mode'], 162 | use_frame_stride_condition=unet_additional_kwargs['use_frame_stride_condition'], 163 | use_safetensors=True 164 | ) 165 | 166 | # Freeze vae and text_encoder 167 | vae.requires_grad_(False) 168 | text_encoder.requires_grad_(False) 169 | unet.train() 170 | 171 | if use_ema: 172 | ema_unet = VideoLDMUNet3DConditionModel.from_pretrained( 173 | pretrained_model_path, 174 | subfolder="unet", 175 | variant=unet_additional_kwargs['variant'], 176 | use_temporal=True if not is_image else False, 177 | temp_pos_embedding=unet_additional_kwargs['temp_pos_embedding'], 178 | augment_temporal_attention=unet_additional_kwargs['augment_temporal_attention'], 179 | n_frames=train_data.sample_n_frames if not is_image else 2, 180 | n_temp_heads=unet_additional_kwargs['n_temp_heads'], 181 | first_frame_condition_mode=unet_additional_kwargs['first_frame_condition_mode'], 182 | use_frame_stride_condition=unet_additional_kwargs['use_frame_stride_condition'], 183 | use_safetensors=True 184 | ) 185 | ema_unet = 
EMAModel(ema_unet.parameters(), decay=ema_decay, model_cls=VideoLDMUNet3DConditionModel, model_config=ema_unet.config) 186 | 187 | # Set unet trainable parameters 188 | train_all_parameters = False 189 | for trainable_module_name in trainable_modules: 190 | if trainable_module_name == 'all': 191 | unet.requires_grad_(True) 192 | train_all_parameters = True 193 | break 194 | 195 | if not train_all_parameters: 196 | unet.requires_grad_(False) 197 | for name, param in unet.named_parameters(): 198 | for trainable_module_name in trainable_modules: 199 | if trainable_module_name in name: 200 | param.requires_grad = True 201 | break 202 | 203 | # Enable xformers 204 | if enable_xformers_memory_efficient_attention and int(torch.__version__.split(".")[0]) < 2: 205 | if is_xformers_available(): 206 | unet.enable_xformers_memory_efficient_attention() 207 | else: 208 | raise ValueError("xformers is not available. Make sure it is installed correctly") 209 | 210 | def save_model_hook(models, weights, output_dir): 211 | if accelerator.is_main_process: 212 | if use_ema: 213 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 214 | 215 | for i, model in enumerate(models): 216 | model.save_pretrained(os.path.join(output_dir, "unet")) 217 | 218 | # make sure to pop weight so that corresponding model is not saved again 219 | weights.pop() 220 | 221 | def load_model_hook(models, input_dir): 222 | if use_ema: 223 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), VideoLDMUNet3DConditionModel) 224 | ema_unet.load_state_dict(load_model.state_dict()) 225 | ema_unet.to(accelerator.device) 226 | del load_model 227 | 228 | for i in range(len(models)): 229 | # pop models so that they are not loaded again 230 | model = models.pop() 231 | 232 | # load diffusers style into model 233 | load_model = VideoLDMUNet3DConditionModel.from_pretrained(input_dir, subfolder="unet") 234 | model.register_to_config(**load_model.config) 235 | 236 | model.load_state_dict(load_model.state_dict()) 237 | del load_model 238 | 239 | accelerator.register_save_state_pre_hook(save_model_hook) 240 | accelerator.register_load_state_pre_hook(load_model_hook) 241 | 242 | # Enable gradient checkpointing 243 | if gradient_checkpointing: 244 | unet.enable_gradient_checkpointing() 245 | 246 | if scale_lr: 247 | learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes) 248 | 249 | trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters())) 250 | optimizer = torch.optim.AdamW( 251 | trainable_params, 252 | lr=learning_rate, 253 | betas=(adam_beta1, adam_beta2), 254 | weight_decay=adam_weight_decay, 255 | eps=adam_epsilon, 256 | ) 257 | 258 | logger.info(f"trainable params number: {len(trainable_params)}") 259 | logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M") 260 | 261 | # Get the training dataset 262 | if train_data['dataset'] == "webvid": 263 | train_dataset = WebVid10M(**train_data, is_image=is_image) 264 | elif train_data['dataset'] == "pexels": 265 | train_dataset = Pexels(**train_data, is_image=is_image) 266 | elif train_data['dataset'] == "joint": 267 | train_dataset = JointDataset(**train_data, is_image=is_image) 268 | else: 269 | raise ValueError(f"Unknown dataset {train_data['dataset']}") 270 | 271 | # DataLoaders creation: 272 | train_dataloader = torch.utils.data.DataLoader( 273 | train_dataset, 274 | shuffle=True, 275 | batch_size=train_batch_size, 276 | num_workers=num_workers, 277 | 
pin_memory=True, 278 | ) 279 | 280 | # Get the training iteration 281 | if max_train_steps == -1: 282 | assert max_train_epoch != -1 283 | max_train_steps = max_train_epoch * len(train_dataloader) 284 | 285 | if checkpointing_steps == -1: 286 | assert checkpointing_epochs != -1 287 | checkpointing_steps = checkpointing_epochs * len(train_dataloader) 288 | 289 | # Scheduler 290 | lr_scheduler = get_scheduler( 291 | lr_scheduler, 292 | optimizer=optimizer, 293 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, 294 | num_training_steps=max_train_steps * gradient_accumulation_steps, 295 | ) 296 | 297 | # Validation pipeline 298 | validation_pipeline = ConditionalAnimationPipeline( 299 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, 300 | ) 301 | validation_pipeline.enable_vae_slicing() 302 | 303 | # Prepare everything with our `accelerator`. 304 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 305 | unet, optimizer, train_dataloader, lr_scheduler 306 | ) 307 | 308 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 309 | # as these models are only used for inference, keeping weights in full precision is not required. 310 | weight_dtype = torch.float32 311 | if accelerator.mixed_precision == "fp16": 312 | weight_dtype = torch.float16 313 | elif accelerator.mixed_precision == "bf16": 314 | weight_dtype = torch.bfloat16 315 | 316 | if use_ema: 317 | ema_unet.to(accelerator.device) 318 | 319 | # Move text_encode and vae to gpu and cast to weight_dtype 320 | text_encoder.to(accelerator.device, dtype=weight_dtype) 321 | vae.to(accelerator.device, dtype=weight_dtype) 322 | 323 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 324 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 325 | # Afterwards we recalculate our number of training epochs 326 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 327 | 328 | # Train! 329 | total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps 330 | 331 | logger.info("***** Running training *****") 332 | logger.info(f" Num examples = {len(train_dataset)}") 333 | logger.info(f" Num Epochs = {num_train_epochs}") 334 | logger.info(f" Instantaneous batch size per device = {train_batch_size}") 335 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 336 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 337 | logger.info(f" Total optimization steps = {max_train_steps}") 338 | 339 | global_step = 0 340 | first_epoch = 0 341 | 342 | # Load pretrained unet weights 343 | if resume_from_checkpoint is not None: 344 | logger.info(f"Resuming from checkpoint: {resume_from_checkpoint}") 345 | accelerator.load_state(resume_from_checkpoint) 346 | global_step = int(resume_from_checkpoint.split("-")[-1]) 347 | 348 | initial_global_step = global_step 349 | first_epoch = global_step // num_update_steps_per_epoch 350 | logger.info(f"global_step: {global_step}") 351 | logger.info(f"first_epoch: {first_epoch}") 352 | else: 353 | initial_global_step = 0 354 | 355 | # Only show the progress bar once on each machine. 
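#
# Illustrative sketch, not part of the original train.py: the training loop below samples the noise
# added to the video latents with one of three schemes ("vanilla", "pyoco_mixed",
# "pyoco_progressive"). The two "pyoco" variants correlate noise across frames by mixing a
# component shared between frames with per-frame components, weighted by `noise_alpha`. A minimal
# standalone version of the same arithmetic (the helper name and dummy shapes are ours):
import math
import torch

def sample_correlated_noise(latents, method="pyoco_mixed", noise_alpha=1.0):
    alpha_sq = noise_alpha ** 2
    if method == "vanilla":
        return torch.randn_like(latents)
    if method == "pyoco_mixed":
        # One noise map shared by every frame, plus independent per-frame noise.
        shared = torch.randn_like(latents[:, :, 0:1]) * math.sqrt(alpha_sq / (1 + alpha_sq))
        independent = torch.randn_like(latents) * math.sqrt(1 / (1 + alpha_sq))
        return shared + independent
    if method == "pyoco_progressive":
        # Each frame reuses a scaled copy of the previous frame's noise plus fresh noise.
        noise = torch.randn_like(latents)
        independent = torch.randn_like(latents) * math.sqrt(1 / (1 + alpha_sq))
        for i in range(1, noise.shape[2]):
            noise[:, :, i] = noise[:, :, i - 1] * math.sqrt(alpha_sq / (1 + alpha_sq)) + independent[:, :, i]
        return noise
    raise ValueError(f"Unknown noise sampling method {method}")

dummy_latents = torch.randn(1, 4, 8, 32, 32)  # (batch, channels, frames, height, width)
noise = sample_correlated_noise(dummy_latents, "pyoco_mixed")
#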
356 | progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps", disable=not accelerator.is_main_process) 357 | 358 | for epoch in range(first_epoch, num_train_epochs): 359 | train_loss = 0.0 360 | train_grad_norm = 0.0 361 | data_loading_time = 0.0 362 | prepare_everything_time = 0.0 363 | network_forward_time = 0.0 364 | network_backward_time = 0.0 365 | 366 | t0 = time.time() 367 | for step, batch in enumerate(train_dataloader): 368 | t1 = time.time() 369 | if cfg_random_null_text_ratio > 0.0: 370 | batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']] 371 | 372 | # Data batch sanity check 373 | if accelerator.is_main_process and epoch == first_epoch and step == 0: 374 | pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] 375 | pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") 376 | for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): 377 | pixel_value = pixel_value[None, ...] 378 | save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'no_text-{idx}'}.gif", rescale=True) 379 | 380 | ### >>>> Training >>>> ### 381 | with accelerator.accumulate(unet): 382 | # Convert videos to latent space 383 | pixel_values = batch["pixel_values"].to(weight_dtype) 384 | video_length = pixel_values.shape[1] 385 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") 386 | latents = vae.encode(pixel_values).latent_dist 387 | latents = latents.sample() 388 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) 389 | 390 | latents = latents * vae.config.scaling_factor 391 | 392 | if unet_additional_kwargs["first_frame_condition_mode"] != "none": 393 | # Get first frame latents 394 | first_frame_latents = latents[:, :, 0:1, :, :] 395 | 396 | # Sample noise that we'll add to the latents 397 | if unet_additional_kwargs['noise_sampling_method'] == 'vanilla': 398 | noise = torch.randn_like(latents) 399 | elif unet_additional_kwargs['noise_sampling_method'] == 'pyoco_mixed': 400 | noise_alpha_squared = float(unet_additional_kwargs['noise_alpha']) ** 2 401 | shared_noise = torch.randn_like(latents[:, :, 0:1, :, :]) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) 402 | ind_noise = torch.randn_like(latents) * math.sqrt(1 / (1 + noise_alpha_squared)) 403 | noise = shared_noise + ind_noise 404 | elif unet_additional_kwargs['noise_sampling_method'] == 'pyoco_progressive': 405 | noise_alpha_squared = float(unet_additional_kwargs['noise_alpha']) ** 2 406 | noise = torch.randn_like(latents) 407 | ind_noise = torch.randn_like(latents) * math.sqrt(1 / (1 + noise_alpha_squared)) 408 | for i in range(1, noise.shape[2]): 409 | noise[:, :, i, :, :] = noise[:, :, i - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_noise[:, :, i, :, :] 410 | else: 411 | raise ValueError(f"Unknown noise sampling method {unet_additional_kwargs['noise_sampling_method']}") 412 | 413 | bsz = latents.shape[0] 414 | 415 | # Sample a random timestep for each video 416 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 417 | timesteps = timesteps.long() 418 | 419 | # Add noise to the latents according to the noise magnitude at each timestep 420 | # (this is the forward diffusion process) 421 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 422 | 423 | if cfg_random_null_img_ratio > 0.0: 424 | for i in 
range(first_frame_latents.shape[0]): 425 | if random.random() <= cfg_random_null_img_ratio: 426 | first_frame_latents[i, :, :, :, :] = noisy_latents[i, :, 0:1, :, :] 427 | 428 | # Remove the first noisy latent from the latents if we're conditioning on the first frame 429 | if unet_additional_kwargs["first_frame_condition_mode"] != "none": 430 | noisy_latents = noisy_latents[:, :, 1:, :, :] 431 | 432 | # Get the text embedding for conditioning 433 | prompt_ids = tokenizer( 434 | batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 435 | ).input_ids.to(latents.device) 436 | encoder_hidden_states = text_encoder(prompt_ids)[0] 437 | 438 | # Get the target for loss depending on the prediction type 439 | if noise_scheduler.config.prediction_type == "epsilon": 440 | target = noise 441 | elif noise_scheduler.config.prediction_type == "v_prediction": 442 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 443 | else: 444 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 445 | 446 | timesteps = repeat(timesteps, "b -> b f", f=video_length) 447 | timesteps = rearrange(timesteps, "b f -> (b f)") 448 | 449 | frame_stride = None 450 | if unet_additional_kwargs["use_frame_stride_condition"]: 451 | frame_stride = batch['stride'].to(latents.device) 452 | frame_stride = frame_stride.long() 453 | frame_stride = repeat(frame_stride, "b -> b f", f=video_length) 454 | frame_stride = rearrange(frame_stride, "b f -> (b f)") 455 | 456 | t2 = time.time() 457 | 458 | # Predict the noise residual and compute loss 459 | if unet_additional_kwargs["first_frame_condition_mode"] != "none": 460 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample 461 | loss = F.mse_loss(model_pred.float(), target.float()[:, :, 1:, :, :], reduction="mean") 462 | else: 463 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 464 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 465 | 466 | t3 = time.time() 467 | 468 | avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() 469 | train_loss += avg_loss.item() / gradient_accumulation_steps 470 | 471 | # Backpropagate 472 | accelerator.backward(loss) 473 | if accelerator.sync_gradients: 474 | grad_norm = accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) 475 | avg_grad_norm = accelerator.gather(grad_norm.repeat(train_batch_size)).mean() 476 | train_grad_norm += avg_grad_norm.item() / gradient_accumulation_steps 477 | 478 | optimizer.step() 479 | lr_scheduler.step() 480 | optimizer.zero_grad() 481 | 482 | t4 = time.time() 483 | 484 | data_loading_time += (t1 - t0) / gradient_accumulation_steps 485 | prepare_everything_time += (t2 - t1) / gradient_accumulation_steps 486 | network_forward_time += (t3 - t2) / gradient_accumulation_steps 487 | network_backward_time += (t4 - t3) / gradient_accumulation_steps 488 | 489 | t0 = time.time() 490 | 491 | ### <<<< Training <<<< ### 492 | 493 | # Checks if the accelerator has performed an optimization step behind the scenes 494 | if accelerator.sync_gradients: 495 | if use_ema: 496 | ema_unet.step(unet.parameters()) 497 | progress_bar.update(1) 498 | global_step += 1 499 | 500 | # Wandb logging 501 | if accelerator.is_main_process and (not is_debug) and use_wandb: 502 | wandb.log({"metrics/train_loss": train_loss}, step=global_step) 503 | wandb.log({"metrics/train_grad_norm": 
train_grad_norm}, step=global_step) 504 | 505 | wandb.log({"profiling/train_data_loading_time": data_loading_time}, step=global_step) 506 | wandb.log({"profiling/train_prepare_everything_time": prepare_everything_time}, step=global_step) 507 | wandb.log({"profiling/train_network_forward_time": network_forward_time}, step=global_step) 508 | wandb.log({"profiling/train_network_backward_time": network_backward_time}, step=global_step) 509 | # accelerator.log({"train_loss": train_loss}, step=global_step) 510 | train_loss = 0.0 511 | train_grad_norm = 0.0 512 | data_loading_time = 0.0 513 | prepare_everything_time = 0.0 514 | network_forward_time = 0.0 515 | network_backward_time = 0.0 516 | 517 | # Save checkpoint 518 | if global_step % checkpointing_steps == 0: 519 | if accelerator.is_main_process: 520 | save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}") 521 | accelerator.save_state(save_path) 522 | logger.info(f"Saved state to {save_path} (global_step: {global_step})") 523 | 524 | # Periodically validation 525 | if accelerator.is_main_process and global_step % validation_steps == 0: 526 | if use_ema: 527 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 528 | ema_unet.store(unet.parameters()) 529 | ema_unet.copy_to(unet.parameters()) 530 | 531 | samples = [] 532 | wandb_samples = [] 533 | 534 | generator = torch.Generator(device=latents.device) 535 | generator.manual_seed(seed) 536 | 537 | height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size 538 | width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size 539 | 540 | prompts = validation_data.prompts 541 | 542 | first_frame_paths = [None] * len(prompts) 543 | if unet_additional_kwargs["first_frame_condition_mode"] != "none": 544 | first_frame_paths = validation_data.path_to_first_frames 545 | 546 | for idx, (prompt, first_frame_path) in enumerate(zip(prompts, first_frame_paths)): 547 | sample = validation_pipeline( 548 | prompt, 549 | generator = generator, 550 | video_length = train_data.sample_n_frames if not is_image else 2, 551 | height = height, 552 | width = width, 553 | first_frame_paths = first_frame_path, 554 | noise_sampling_method = unet_additional_kwargs['noise_sampling_method'], 555 | noise_alpha = float(unet_additional_kwargs['noise_alpha']), 556 | **validation_data, 557 | ).videos 558 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") 559 | samples.append(sample) 560 | 561 | numpy_sample = (sample.squeeze(0).permute(1, 0, 2, 3) * 255).cpu().numpy().astype(np.uint8) 562 | wandb_video = wandb.Video(numpy_sample, fps=8, caption=prompt) 563 | wandb_samples.append(wandb_video) 564 | 565 | if (not is_debug) and use_wandb: 566 | val_title = 'val_videos' 567 | wandb.log({val_title: wandb_samples}, step=global_step) 568 | 569 | samples = torch.concat(samples) 570 | save_path = f"{output_dir}/samples/sample-{global_step}.gif" 571 | save_videos_grid(samples, save_path) 572 | 573 | logger.info(f"Saved samples to {save_path}") 574 | 575 | if use_ema: 576 | # Switch back to the original UNet parameters. 
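#
# Illustrative sketch, not part of the original train.py: for validation, the EMA weights are
# swapped in with store()/copy_to() and swapped back with restore() on the line that follows.
# The same pattern with a tiny stand-in model (the Linear layer is only a placeholder):
import torch
from diffusers.training_utils import EMAModel

net = torch.nn.Linear(4, 4)
ema = EMAModel(net.parameters(), decay=0.9999)

ema.step(net.parameters())     # called after each optimizer update during training
ema.store(net.parameters())    # stash the raw training weights
ema.copy_to(net.parameters())  # run validation/inference with the EMA weights
# ... validation would run here ...
ema.restore(net.parameters())  # put the raw training weights back
#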
577 | ema_unet.restore(unet.parameters()) 578 | 579 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 580 | progress_bar.set_postfix(**logs) 581 | 582 | if accelerator.is_main_process and (not is_debug) and use_wandb: 583 | wandb.log({"metrics/train_lr": lr_scheduler.get_last_lr()[0]}, step=global_step) 584 | 585 | if global_step >= max_train_steps: 586 | break 587 | 588 | # Create the pipeline using the trained modules and save it. 589 | accelerator.wait_for_everyone() 590 | if accelerator.is_main_process: 591 | unet = accelerator.unwrap_model(unet) 592 | pipeline = ConditionalAnimationPipeline( 593 | text_encoder=text_encoder, 594 | vae=vae, 595 | unet=unet, 596 | tokenizer=tokenizer, 597 | scheduler=noise_scheduler, 598 | ) 599 | pipeline.save_pretrained(f"{output_dir}/final_checkpoint") 600 | 601 | 602 | if __name__ == "__main__": 603 | parser = argparse.ArgumentParser() 604 | parser.add_argument("--config", type=str, required=True) 605 | parser.add_argument("--name", "-n", type=str, default="") 606 | parser.add_argument("--wandb", action="store_true") 607 | parser.add_argument("optional_args", nargs='*', default=[]) 608 | args = parser.parse_args() 609 | 610 | name = args.name + "_" + Path(args.config).stem 611 | config = OmegaConf.load(args.config) 612 | 613 | if args.optional_args: 614 | modified_config = OmegaConf.from_dotlist(args.optional_args) 615 | config = OmegaConf.merge(config, modified_config) 616 | 617 | main(name=name, use_wandb=args.wandb, **config) 618 | -------------------------------------------------------------------------------- /consisti2v/pipelines/pipeline_autoregress_animation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | 3 | import inspect 4 | from typing import Callable, List, Optional, Union 5 | from dataclasses import dataclass 6 | 7 | import math 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from torchvision import transforms as T 13 | from PIL import Image 14 | 15 | from diffusers.utils import is_accelerate_available 16 | from packaging import version 17 | from transformers import CLIPTextModel, CLIPTokenizer 18 | 19 | from diffusers.configuration_utils import FrozenDict 20 | from diffusers.models import AutoencoderKL 21 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 22 | from diffusers.schedulers import ( 23 | DDIMScheduler, 24 | DPMSolverMultistepScheduler, 25 | EulerAncestralDiscreteScheduler, 26 | EulerDiscreteScheduler, 27 | LMSDiscreteScheduler, 28 | PNDMScheduler, 29 | ) 30 | from diffusers.utils import deprecate, logging, BaseOutput 31 | 32 | from einops import rearrange, repeat 33 | 34 | from ..models.unet import UNet3DConditionModel 35 | from ..utils.frameinit_utils import freq_mix_3d, get_freq_filter 36 | 37 | 38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 39 | 40 | # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21 41 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 42 | """ 43 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 44 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 45 | """ 46 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 47 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 48 | # rescale the results from guidance (fixes overexposure) 49 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 50 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 51 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 52 | return noise_cfg 53 | 54 | 55 | @dataclass 56 | class AnimationPipelineOutput(BaseOutput): 57 | videos: Union[torch.Tensor, np.ndarray] 58 | 59 | 60 | class AutoregressiveAnimationPipeline(DiffusionPipeline): 61 | _optional_components = [] 62 | 63 | def __init__( 64 | self, 65 | vae: AutoencoderKL, 66 | text_encoder: CLIPTextModel, 67 | tokenizer: CLIPTokenizer, 68 | unet: UNet3DConditionModel, 69 | scheduler: Union[ 70 | DDIMScheduler, 71 | PNDMScheduler, 72 | LMSDiscreteScheduler, 73 | EulerDiscreteScheduler, 74 | EulerAncestralDiscreteScheduler, 75 | DPMSolverMultistepScheduler, 76 | ], 77 | ): 78 | super().__init__() 79 | 80 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 81 | deprecation_message = ( 82 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 83 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 84 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 85 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 86 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 87 | " file" 88 | ) 89 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 90 | new_config = dict(scheduler.config) 91 | new_config["steps_offset"] = 1 92 | scheduler._internal_dict = FrozenDict(new_config) 93 | 94 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 95 | deprecation_message = ( 96 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 97 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 98 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 99 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 100 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 101 | ) 102 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 103 | new_config = dict(scheduler.config) 104 | new_config["clip_sample"] = False 105 | scheduler._internal_dict = FrozenDict(new_config) 106 | 107 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 108 | version.parse(unet.config._diffusers_version).base_version 109 | ) < version.parse("0.9.0.dev0") 110 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 111 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 112 | deprecation_message = ( 113 | "The configuration file of the unet has set the default `sample_size` to smaller than" 114 | " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" 115 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 116 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 117 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 118 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 119 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 120 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 121 | " the `unet/config.json` file" 122 | ) 123 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 124 | new_config = dict(unet.config) 125 | new_config["sample_size"] = 64 126 | unet._internal_dict = FrozenDict(new_config) 127 | 128 | self.register_modules( 129 | vae=vae, 130 | text_encoder=text_encoder, 131 | tokenizer=tokenizer, 132 | unet=unet, 133 | scheduler=scheduler, 134 | ) 135 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 136 | 137 | self.freq_filter = None 138 | 139 | @torch.no_grad() 140 | def init_filter(self, video_length, height, width, filter_params): 141 | # initialize frequency filter for noise reinitialization 142 | batch_size = 1 143 | num_channels_latents = self.unet.config.in_channels 144 | filter_shape = [ 145 | batch_size, 146 | num_channels_latents, 147 | video_length, 148 | height // self.vae_scale_factor, 149 | width // self.vae_scale_factor 150 | ] 151 | # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params) 152 | self.freq_filter = get_freq_filter( 153 | filter_shape, 154 | device=self._execution_device, 155 | filter_type=filter_params.method, 156 | n=filter_params.n if filter_params.method=="butterworth" else None, 157 | d_s=filter_params.d_s, 158 | d_t=filter_params.d_t 159 | ) 160 | 161 | def enable_vae_slicing(self): 162 | self.vae.enable_slicing() 163 | 164 | def disable_vae_slicing(self): 165 | self.vae.disable_slicing() 166 | 167 | def enable_sequential_cpu_offload(self, gpu_id=0): 168 | if is_accelerate_available(): 169 | from accelerate import cpu_offload 170 | else: 171 | raise ImportError("Please install accelerate via `pip install accelerate`") 172 | 173 | device = torch.device(f"cuda:{gpu_id}") 174 | 175 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 176 | if cpu_offloaded_model is not None: 177 | cpu_offload(cpu_offloaded_model, device) 178 | 179 | 180 | @property 181 | def _execution_device(self): 182 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 183 | return self.device 184 | for module in self.unet.modules(): 185 | if ( 186 | hasattr(module, "_hf_hook") 187 | and hasattr(module._hf_hook, "execution_device") 188 | and module._hf_hook.execution_device is not None 189 | ): 190 | return torch.device(module._hf_hook.execution_device) 191 | return self.device 192 | 193 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 194 | batch_size = len(prompt) if isinstance(prompt, list) else 1 195 | 196 | text_inputs = self.tokenizer( 197 | prompt, 198 | padding="max_length", 199 | max_length=self.tokenizer.model_max_length, 200 | truncation=True, 201 | return_tensors="pt", 202 | ) 203 | text_input_ids = text_inputs.input_ids 204 | untruncated_ids = self.tokenizer(prompt, 
padding="longest", return_tensors="pt").input_ids 205 | 206 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 207 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 208 | logger.warning( 209 | "The following part of your input was truncated because CLIP can only handle sequences up to" 210 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 211 | ) 212 | 213 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 214 | attention_mask = text_inputs.attention_mask.to(device) 215 | else: 216 | attention_mask = None 217 | 218 | text_embeddings = self.text_encoder( 219 | text_input_ids.to(device), 220 | attention_mask=attention_mask, 221 | ) 222 | text_embeddings = text_embeddings[0] 223 | 224 | # duplicate text embeddings for each generation per prompt, using mps friendly method 225 | bs_embed, seq_len, _ = text_embeddings.shape 226 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 227 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 228 | 229 | # get unconditional embeddings for classifier free guidance 230 | if do_classifier_free_guidance is not None: 231 | uncond_tokens: List[str] 232 | if negative_prompt is None: 233 | uncond_tokens = [""] * batch_size 234 | elif type(prompt) is not type(negative_prompt): 235 | raise TypeError( 236 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 237 | f" {type(prompt)}." 238 | ) 239 | elif isinstance(negative_prompt, str): 240 | uncond_tokens = [negative_prompt] 241 | elif batch_size != len(negative_prompt): 242 | raise ValueError( 243 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 244 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 245 | " the batch size of `prompt`." 246 | ) 247 | else: 248 | uncond_tokens = negative_prompt 249 | 250 | max_length = text_input_ids.shape[-1] 251 | uncond_input = self.tokenizer( 252 | uncond_tokens, 253 | padding="max_length", 254 | max_length=max_length, 255 | truncation=True, 256 | return_tensors="pt", 257 | ) 258 | 259 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 260 | attention_mask = uncond_input.attention_mask.to(device) 261 | else: 262 | attention_mask = None 263 | 264 | uncond_embeddings = self.text_encoder( 265 | uncond_input.input_ids.to(device), 266 | attention_mask=attention_mask, 267 | ) 268 | uncond_embeddings = uncond_embeddings[0] 269 | 270 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 271 | seq_len = uncond_embeddings.shape[1] 272 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 273 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 274 | 275 | # For classifier free guidance, we need to do two forward passes. 
276 | # Here we concatenate the unconditional and text embeddings into a single batch 277 | # to avoid doing two forward passes 278 | if do_classifier_free_guidance == "text": 279 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 280 | elif do_classifier_free_guidance == "both": 281 | text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings]) 282 | 283 | return text_embeddings 284 | 285 | def decode_latents(self, latents, first_frames=None): 286 | video_length = latents.shape[2] 287 | latents = 1 / self.vae.config.scaling_factor * latents 288 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 289 | # video = self.vae.decode(latents).sample 290 | video = [] 291 | for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config): 292 | video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) 293 | video = torch.cat(video) 294 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 295 | 296 | if first_frames is not None: 297 | first_frames = first_frames.unsqueeze(2) 298 | video = torch.cat([first_frames, video], dim=2) 299 | 300 | video = (video / 2 + 0.5).clamp(0, 1) 301 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 302 | video = video.cpu().float().numpy() 303 | return video 304 | 305 | def prepare_extra_step_kwargs(self, generator, eta): 306 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 307 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 308 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 309 | # and should be between [0, 1] 310 | 311 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 312 | extra_step_kwargs = {} 313 | if accepts_eta: 314 | extra_step_kwargs["eta"] = eta 315 | 316 | # check if the scheduler accepts generator 317 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 318 | if accepts_generator: 319 | extra_step_kwargs["generator"] = generator 320 | return extra_step_kwargs 321 | 322 | def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None): 323 | if not isinstance(prompt, str) and not isinstance(prompt, list): 324 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 325 | 326 | if first_frame_paths is not None and (not isinstance(first_frame_paths, str) and not isinstance(first_frame_paths, list)): 327 | raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}") 328 | 329 | if height % 8 != 0 or width % 8 != 0: 330 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 331 | 332 | if (callback_steps is None) or ( 333 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 334 | ): 335 | raise ValueError( 336 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 337 | f" {type(callback_steps)}." 
338 | ) 339 | 340 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0): 341 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 342 | if isinstance(generator, list) and len(generator) != batch_size: 343 | raise ValueError( 344 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 345 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 346 | ) 347 | if latents is None: 348 | rand_device = "cpu" if device.type == "mps" else device 349 | 350 | if isinstance(generator, list): 351 | # shape = shape 352 | shape = (1,) + shape[1:] 353 | if noise_sampling_method == "vanilla": 354 | latents = [ 355 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 356 | for i in range(batch_size) 357 | ] 358 | elif noise_sampling_method == "pyoco_mixed": 359 | base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) 360 | latents = [] 361 | noise_alpha_squared = noise_alpha ** 2 362 | for i in range(batch_size): 363 | base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) 364 | ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 365 | latents.append(base_latent + ind_latent) 366 | elif noise_sampling_method == "pyoco_progressive": 367 | latents = [] 368 | noise_alpha_squared = noise_alpha ** 2 369 | for i in range(batch_size): 370 | latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 371 | ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 372 | for j in range(1, video_length): 373 | latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :] 374 | latents.append(latent) 375 | latents = torch.cat(latents, dim=0).to(device) 376 | else: 377 | if noise_sampling_method == "vanilla": 378 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 379 | elif noise_sampling_method == "pyoco_mixed": 380 | noise_alpha_squared = noise_alpha ** 2 381 | base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) 382 | base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) 383 | ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 384 | latents = base_latents + ind_latents 385 | elif noise_sampling_method == "pyoco_progressive": 386 | noise_alpha_squared = noise_alpha ** 2 387 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) 388 | ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 389 | for j in range(1, video_length): 390 | latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :] 391 | else: 392 | if latents.shape != shape: 
393 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 394 | latents = latents.to(device) 395 | 396 | # scale the initial noise by the standard deviation required by the scheduler 397 | latents = latents * self.scheduler.init_noise_sigma 398 | return latents 399 | 400 | @torch.no_grad() 401 | def __call__( 402 | self, 403 | prompt: Union[str, List[str]], 404 | video_length: Optional[int], 405 | height: Optional[int] = None, 406 | width: Optional[int] = None, 407 | num_inference_steps: int = 50, 408 | guidance_scale_txt: float = 7.5, 409 | guidance_scale_img: float = 2.0, 410 | negative_prompt: Optional[Union[str, List[str]]] = None, 411 | num_videos_per_prompt: Optional[int] = 1, 412 | eta: float = 0.0, 413 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 414 | latents: Optional[torch.FloatTensor] = None, 415 | output_type: Optional[str] = "tensor", 416 | return_dict: bool = True, 417 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 418 | callback_steps: Optional[int] = 1, 419 | # additional 420 | first_frame_paths: Optional[Union[str, List[str]]] = None, 421 | first_frames: Optional[torch.FloatTensor] = None, 422 | noise_sampling_method: str = "vanilla", 423 | noise_alpha: float = 1.0, 424 | guidance_rescale: float = 0.0, 425 | frame_stride: Optional[int] = None, 426 | autoregress_steps: int = 3, 427 | use_frameinit: bool = False, 428 | frameinit_noise_level: int = 999, 429 | **kwargs, 430 | ): 431 | if first_frame_paths is not None and first_frames is not None: 432 | raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.") 433 | # Default height and width to unet 434 | height = height or self.unet.config.sample_size * self.vae_scale_factor 435 | width = width or self.unet.config.sample_size * self.vae_scale_factor 436 | 437 | # Check inputs. Raise error if not correct 438 | self.check_inputs(prompt, height, width, callback_steps, first_frame_paths) 439 | 440 | # Define call parameters 441 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 442 | batch_size = 1 443 | if latents is not None: 444 | batch_size = latents.shape[0] 445 | if isinstance(prompt, list): 446 | batch_size = len(prompt) 447 | first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames 448 | if first_frame_input is not None: 449 | assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length" 450 | 451 | device = self._execution_device 452 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 453 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 454 | # corresponds to doing no classifier free guidance. 
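#
# Illustrative sketch, not part of the original pipeline file: the block below picks between no
# guidance, text-only guidance, and joint text+image guidance. In the "both" mode the UNet is run
# on a 3-way batch (unconditional, image-conditioned, image+text-conditioned) and the two guidance
# scales are applied in sequence, as in the denoising loop further down. A standalone version of
# that combination step on dummy predictions (the helper name is ours):
import torch

def combine_dual_cfg(noise_pred, guidance_scale_img=2.0, guidance_scale_txt=7.5):
    # noise_pred stacks the three predictions along the batch dimension.
    noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
    return (noise_pred_uncond
            + guidance_scale_img * (noise_pred_img - noise_pred_uncond)
            + guidance_scale_txt * (noise_pred_both - noise_pred_img))

dummy_pred = torch.randn(3, 4, 15, 32, 32)  # (3 * batch, channels, frames, height, width)
guided_pred = combine_dual_cfg(dummy_pred)
#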
455 | do_classifier_free_guidance = None 456 | # two guidance mode: text and text+image 457 | if guidance_scale_txt > 1.0: 458 | do_classifier_free_guidance = "text" 459 | if guidance_scale_img > 1.0: 460 | do_classifier_free_guidance = "both" 461 | 462 | # Encode input prompt 463 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 464 | if negative_prompt is not None: 465 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 466 | text_embeddings = self._encode_prompt( 467 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 468 | ) 469 | 470 | # Encode input first frame 471 | first_frame_latents = None 472 | if first_frame_paths is not None: 473 | first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size 474 | img_transform = T.Compose([ 475 | T.ToTensor(), 476 | T.Resize(height, antialias=None), 477 | T.CenterCrop((height, width)), 478 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 479 | ]) 480 | first_frames = [] 481 | for first_frame_path in first_frame_paths: 482 | first_frame = Image.open(first_frame_path).convert('RGB') 483 | first_frame = img_transform(first_frame).unsqueeze(0) 484 | first_frames.append(first_frame) 485 | first_frames = torch.cat(first_frames, dim=0) 486 | if first_frames is not None: 487 | first_frames = first_frames.to(device, dtype=self.vae.dtype) 488 | first_frame_latents = self.vae.encode(first_frames).latent_dist 489 | first_frame_latents = first_frame_latents.sample() 490 | first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w 491 | first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt) 492 | first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt) 493 | 494 | full_video_latent = torch.zeros(batch_size * num_videos_per_prompt, self.unet.config.in_channels, video_length * autoregress_steps - autoregress_steps + 1, height // self.vae_scale_factor, width // self.vae_scale_factor, device=device, dtype=self.vae.dtype) 495 | 496 | start_idx = 0 497 | for ar_step in range(autoregress_steps): 498 | # Prepare timesteps 499 | self.scheduler.set_timesteps(num_inference_steps, device=device) 500 | timesteps = self.scheduler.timesteps 501 | 502 | # Prepare latent variables 503 | num_channels_latents = self.unet.config.in_channels 504 | latents = self.prepare_latents( 505 | batch_size * num_videos_per_prompt, 506 | num_channels_latents, 507 | video_length, 508 | height, 509 | width, 510 | text_embeddings.dtype, 511 | device, 512 | generator, 513 | latents, 514 | noise_sampling_method, 515 | noise_alpha, 516 | ) 517 | latents_dtype = latents.dtype 518 | 519 | if use_frameinit: 520 | current_diffuse_timestep = frameinit_noise_level # diffuse to noise level 521 | diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep)) 522 | diffuse_timesteps = diffuse_timesteps.long() 523 | first_frames_static_vid = repeat(first_frame_latents, "b c h w -> b c t h w", t=video_length) 524 | z_T = self.scheduler.add_noise( 525 | original_samples=first_frames_static_vid.to(device), 526 | noise=latents.to(device), 527 | timesteps=diffuse_timesteps.to(device) 528 | ) 529 | latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents, LPF=self.freq_filter) 530 | latents = latents.to(dtype=latents_dtype) 531 | 532 | if first_frame_latents is not None: 533 | first_frame_noisy_latent = 
latents[:, :, 0, :, :] 534 | latents = latents[:, :, 1:, :, :] 535 | 536 | # Prepare extra step kwargs. 537 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 538 | 539 | # Denoising loop 540 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 541 | with self.progress_bar(total=num_inference_steps) as progress_bar: 542 | for i, t in enumerate(timesteps): 543 | # expand the latents if we are doing classifier free guidance 544 | if do_classifier_free_guidance is None: 545 | latent_model_input = latents 546 | elif do_classifier_free_guidance == "text": 547 | latent_model_input = torch.cat([latents] * 2) 548 | elif do_classifier_free_guidance == "both": 549 | latent_model_input = torch.cat([latents] * 3) 550 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 551 | if first_frame_latents is not None: 552 | if do_classifier_free_guidance is None: 553 | first_frame_latents_input = first_frame_latents 554 | elif do_classifier_free_guidance == "text": 555 | first_frame_latents_input = torch.cat([first_frame_latents] * 2) 556 | elif do_classifier_free_guidance == "both": 557 | first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents]) 558 | 559 | first_frame_latents_input = first_frame_latents_input.unsqueeze(2) 560 | 561 | # predict the noise residual 562 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype) 563 | else: 564 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) 565 | # noise_pred = [] 566 | # import pdb 567 | # pdb.set_trace() 568 | # for batch_idx in range(latent_model_input.shape[0]): 569 | # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype) 570 | # noise_pred.append(noise_pred_single) 571 | # noise_pred = torch.cat(noise_pred) 572 | 573 | # perform guidance 574 | if do_classifier_free_guidance: 575 | if do_classifier_free_guidance == "text": 576 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 577 | noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond) 578 | elif do_classifier_free_guidance == "both": 579 | noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3) 580 | noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img) 581 | 582 | if do_classifier_free_guidance and guidance_rescale > 0.0: 583 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 584 | # currently only support text guidance 585 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 586 | 587 | # compute the previous noisy sample x_t -> x_t-1 588 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 589 | 590 | # call the callback, if provided 591 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 592 | progress_bar.update() 593 | if callback is not None and i % callback_steps == 0: 594 | callback(i, t, latents) 595 | 596 | # Post-processing 597 | 598 | latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2) 599 | first_frame_latents = latents[:, :, -1, :, :] 600 | full_video_latent[:, :, start_idx:start_idx + video_length, :, :] = latents 601 | 602 | latents = None 603 | start_idx += (video_length - 1) 604 | 605 | # video = self.decode_latents(latents, first_frames) 606 | video = self.decode_latents(full_video_latent) 607 | 608 | # Convert to tensor 609 | if output_type == "tensor": 610 | video = torch.from_numpy(video) 611 | 612 | if not return_dict: 613 | return video 614 | 615 | return AnimationPipelineOutput(videos=video) 616 | -------------------------------------------------------------------------------- /consisti2v/pipelines/pipeline_conditional_animation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | 3 | import inspect 4 | from typing import Callable, List, Optional, Union 5 | from dataclasses import dataclass 6 | 7 | import math 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from torchvision import transforms as T 13 | from torchvision.transforms import functional as F 14 | from PIL import Image 15 | 16 | from diffusers.utils import is_accelerate_available 17 | from packaging import version 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | 20 | from diffusers.configuration_utils import FrozenDict 21 | from diffusers.models import AutoencoderKL 22 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 23 | from diffusers.schedulers import ( 24 | DDIMScheduler, 25 | DPMSolverMultistepScheduler, 26 | EulerAncestralDiscreteScheduler, 27 | EulerDiscreteScheduler, 28 | LMSDiscreteScheduler, 29 | PNDMScheduler, 30 | ) 31 | from diffusers.utils import deprecate, logging, BaseOutput 32 | 33 | from einops import rearrange, repeat 34 | 35 | from ..models.videoldm_unet import VideoLDMUNet3DConditionModel 36 | 37 | from ..utils.frameinit_utils import get_freq_filter, freq_mix_3d 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21 43 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 44 | """ 45 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 46 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 47 | """ 48 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 49 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 50 | # rescale the results from guidance (fixes overexposure) 51 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 52 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 53 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 54 | return noise_cfg 55 | 56 | def pan_right(image, num_frames=16, crop_width=256): 57 | frames = [] 58 | height, width = image.shape[-2:] 59 | 60 | for i in range(num_frames): 61 | # Calculate the start position of the crop 62 | start_x = int((width - crop_width) * (i / num_frames)) 63 | crop = F.crop(image, 0, start_x, height, crop_width) 64 | frames.append(crop.unsqueeze(0)) 65 | 66 | return torch.cat(frames, dim=0) 67 | 68 | 69 | def pan_left(image, num_frames=16, crop_width=256): 70 | frames = [] 71 | height, width = image.shape[-2:] 72 | 73 | for i in range(num_frames): 74 | # Start position moves from right to left 75 | start_x = int((width - crop_width) * (1 - (i / num_frames))) 76 | crop = F.crop(image, 0, start_x, height, crop_width) 77 | frames.append(crop.unsqueeze(0)) 78 | 79 | return torch.cat(frames, dim=0) 80 | 81 | 82 | def zoom_in(image, num_frames=16, crop_width=256, ratio=1.5): 83 | frames = [] 84 | height, width = image.shape[-2:] 85 | max_crop_size = min(width, height) 86 | 87 | for i in range(num_frames): 88 | # Calculate the size of the crop 89 | crop_size = max_crop_size - int((max_crop_size - max_crop_size // ratio) * (i / num_frames)) 90 | start_x = (width - crop_size) // 2 91 | start_y = (height - crop_size) // 2 92 | crop = F.crop(image, start_y, start_x, crop_size, crop_size) 93 | resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size 94 | frames.append(resized_crop.unsqueeze(0)) 95 | 96 | return torch.cat(frames, dim=0) 97 | 98 | 99 | def zoom_out(image, num_frames=16, crop_width=256, ratio=1.5): 100 | frames = [] 101 | height, width = image.shape[-2:] 102 | min_crop_size = min(width, height) // ratio # Starting from a quarter of the size 103 | 104 | for i in range(num_frames): 105 | # Calculate the size of the crop 106 | crop_size = min_crop_size + int((min(width, height) - min_crop_size) * (i / num_frames)) 107 | start_x = (width - crop_size) // 2 108 | start_y = (height - crop_size) // 2 109 | crop = F.crop(image, start_y, start_x, crop_size, crop_size) 110 | resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size 111 | frames.append(resized_crop.unsqueeze(0)) 112 | 113 | return torch.cat(frames, dim=0) 114 | 115 | 116 | @dataclass 117 | class AnimationPipelineOutput(BaseOutput): 118 | videos: Union[torch.Tensor, np.ndarray] 119 | 120 | 121 | class ConditionalAnimationPipeline(DiffusionPipeline): 122 | _optional_components = [] 123 | 124 | def __init__( 125 | self, 126 | vae: AutoencoderKL, 127 | text_encoder: CLIPTextModel, 128 | tokenizer: CLIPTokenizer, 129 | unet: VideoLDMUNet3DConditionModel, 130 | scheduler: Union[ 131 | DDIMScheduler, 132 | PNDMScheduler, 133 | LMSDiscreteScheduler, 134 | EulerDiscreteScheduler, 135 | EulerAncestralDiscreteScheduler, 136 | DPMSolverMultistepScheduler, 137 | ], 138 | ): 139 | super().__init__() 140 | 141 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 142 | 
deprecation_message = ( 143 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 144 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 145 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 146 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 147 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 148 | " file" 149 | ) 150 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 151 | new_config = dict(scheduler.config) 152 | new_config["steps_offset"] = 1 153 | scheduler._internal_dict = FrozenDict(new_config) 154 | 155 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 156 | deprecation_message = ( 157 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 158 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 159 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 160 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 161 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 162 | ) 163 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 164 | new_config = dict(scheduler.config) 165 | new_config["clip_sample"] = False 166 | scheduler._internal_dict = FrozenDict(new_config) 167 | 168 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 169 | version.parse(unet.config._diffusers_version).base_version 170 | ) < version.parse("0.9.0.dev0") 171 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 172 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 173 | deprecation_message = ( 174 | "The configuration file of the unet has set the default `sample_size` to smaller than" 175 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 176 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 177 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 178 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 179 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 180 | " in the config might lead to incorrect results in future versions. 
If you have downloaded this" 181 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 182 | " the `unet/config.json` file" 183 | ) 184 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 185 | new_config = dict(unet.config) 186 | new_config["sample_size"] = 64 187 | unet._internal_dict = FrozenDict(new_config) 188 | 189 | self.register_modules( 190 | vae=vae, 191 | text_encoder=text_encoder, 192 | tokenizer=tokenizer, 193 | unet=unet, 194 | scheduler=scheduler, 195 | ) 196 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 197 | 198 | self.freq_filter = None 199 | 200 | @torch.no_grad() 201 | def init_filter(self, video_length, height, width, filter_params): 202 | # initialize frequency filter for noise reinitialization 203 | batch_size = 1 204 | num_channels_latents = self.unet.config.in_channels 205 | filter_shape = [ 206 | batch_size, 207 | num_channels_latents, 208 | video_length, 209 | height // self.vae_scale_factor, 210 | width // self.vae_scale_factor 211 | ] 212 | # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params) 213 | self.freq_filter = get_freq_filter( 214 | filter_shape, 215 | device=self._execution_device, 216 | filter_type=filter_params.method, 217 | n=filter_params.n if filter_params.method=="butterworth" else None, 218 | d_s=filter_params.d_s, 219 | d_t=filter_params.d_t 220 | ) 221 | 222 | def enable_vae_slicing(self): 223 | self.vae.enable_slicing() 224 | 225 | def disable_vae_slicing(self): 226 | self.vae.disable_slicing() 227 | 228 | def enable_sequential_cpu_offload(self, gpu_id=0): 229 | if is_accelerate_available(): 230 | from accelerate import cpu_offload 231 | else: 232 | raise ImportError("Please install accelerate via `pip install accelerate`") 233 | 234 | device = torch.device(f"cuda:{gpu_id}") 235 | 236 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 237 | if cpu_offloaded_model is not None: 238 | cpu_offload(cpu_offloaded_model, device) 239 | 240 | 241 | @property 242 | def _execution_device(self): 243 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 244 | return self.device 245 | for module in self.unet.modules(): 246 | if ( 247 | hasattr(module, "_hf_hook") 248 | and hasattr(module._hf_hook, "execution_device") 249 | and module._hf_hook.execution_device is not None 250 | ): 251 | return torch.device(module._hf_hook.execution_device) 252 | return self.device 253 | 254 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 255 | batch_size = len(prompt) if isinstance(prompt, list) else 1 256 | 257 | text_inputs = self.tokenizer( 258 | prompt, 259 | padding="max_length", 260 | max_length=self.tokenizer.model_max_length, 261 | truncation=True, 262 | return_tensors="pt", 263 | ) 264 | text_input_ids = text_inputs.input_ids 265 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 266 | 267 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 268 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 269 | logger.warning( 270 | "The following part of your input was truncated because CLIP can only handle sequences up to" 271 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 272 | ) 273 | 274 | if hasattr(self.text_encoder.config, 
"use_attention_mask") and self.text_encoder.config.use_attention_mask: 275 | attention_mask = text_inputs.attention_mask.to(device) 276 | else: 277 | attention_mask = None 278 | 279 | text_embeddings = self.text_encoder( 280 | text_input_ids.to(device), 281 | attention_mask=attention_mask, 282 | ) 283 | text_embeddings = text_embeddings[0] 284 | 285 | # duplicate text embeddings for each generation per prompt, using mps friendly method 286 | bs_embed, seq_len, _ = text_embeddings.shape 287 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 288 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 289 | 290 | # get unconditional embeddings for classifier free guidance 291 | if do_classifier_free_guidance is not None: 292 | uncond_tokens: List[str] 293 | if negative_prompt is None: 294 | uncond_tokens = [""] * batch_size 295 | elif type(prompt) is not type(negative_prompt): 296 | raise TypeError( 297 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 298 | f" {type(prompt)}." 299 | ) 300 | elif isinstance(negative_prompt, str): 301 | uncond_tokens = [negative_prompt] 302 | elif batch_size != len(negative_prompt): 303 | raise ValueError( 304 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 305 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 306 | " the batch size of `prompt`." 307 | ) 308 | else: 309 | uncond_tokens = negative_prompt 310 | 311 | max_length = text_input_ids.shape[-1] 312 | uncond_input = self.tokenizer( 313 | uncond_tokens, 314 | padding="max_length", 315 | max_length=max_length, 316 | truncation=True, 317 | return_tensors="pt", 318 | ) 319 | 320 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 321 | attention_mask = uncond_input.attention_mask.to(device) 322 | else: 323 | attention_mask = None 324 | 325 | uncond_embeddings = self.text_encoder( 326 | uncond_input.input_ids.to(device), 327 | attention_mask=attention_mask, 328 | ) 329 | uncond_embeddings = uncond_embeddings[0] 330 | 331 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 332 | seq_len = uncond_embeddings.shape[1] 333 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 334 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 335 | 336 | # For classifier free guidance, we need to do two forward passes. 
337 |             # Here we concatenate the unconditional and text embeddings into a single batch
338 |             # to avoid doing two forward passes
339 |             if do_classifier_free_guidance == "text":
340 |                 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
341 |             elif do_classifier_free_guidance == "both":
342 |                 text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
343 |
344 |         return text_embeddings
345 |
346 |     def decode_latents(self, latents, first_frames=None):
347 |         video_length = latents.shape[2]
348 |         latents = 1 / self.vae.config.scaling_factor * latents
349 |         latents = rearrange(latents, "b c f h w -> (b f) c h w")
350 |         # video = self.vae.decode(latents).sample
351 |         video = []
352 |         for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
353 |             video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
354 |         video = torch.cat(video)
355 |         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
356 |
357 |         if first_frames is not None:
358 |             first_frames = first_frames.unsqueeze(2)
359 |             video = torch.cat([first_frames, video], dim=2)
360 |
361 |         video = (video / 2 + 0.5).clamp(0, 1)
362 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
363 |         video = video.cpu().float().numpy()
364 |         return video
365 |
366 |     def prepare_extra_step_kwargs(self, generator, eta):
367 |         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
368 |         # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
369 |         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
370 |         # and should be between [0, 1]
371 |
372 |         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
373 |         extra_step_kwargs = {}
374 |         if accepts_eta:
375 |             extra_step_kwargs["eta"] = eta
376 |
377 |         # check if the scheduler accepts generator
378 |         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
379 |         if accepts_generator:
380 |             extra_step_kwargs["generator"] = generator
381 |         return extra_step_kwargs
382 |
383 |     def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
384 |         if not isinstance(prompt, str) and not isinstance(prompt, list):
385 |             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
386 |
387 |         if first_frame_paths is not None and (not isinstance(first_frame_paths, str) and not isinstance(first_frame_paths, list)):
388 |             raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
389 |
390 |         if height % 8 != 0 or width % 8 != 0:
391 |             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
392 |
393 |         if (callback_steps is None) or (
394 |             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
395 |         ):
396 |             raise ValueError(
397 |                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
398 |                 f" {type(callback_steps)}."
399 | ) 400 | 401 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0): 402 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 403 | if isinstance(generator, list) and len(generator) != batch_size: 404 | raise ValueError( 405 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 406 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 407 | ) 408 | if latents is None: 409 | rand_device = "cpu" if device.type == "mps" else device 410 | 411 | if isinstance(generator, list): 412 | # shape = shape 413 | shape = (1,) + shape[1:] 414 | if noise_sampling_method == "vanilla": 415 | latents = [ 416 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 417 | for i in range(batch_size) 418 | ] 419 | elif noise_sampling_method == "pyoco_mixed": 420 | base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) 421 | latents = [] 422 | noise_alpha_squared = noise_alpha ** 2 423 | for i in range(batch_size): 424 | base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) 425 | ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 426 | latents.append(base_latent + ind_latent) 427 | elif noise_sampling_method == "pyoco_progressive": 428 | latents = [] 429 | noise_alpha_squared = noise_alpha ** 2 430 | for i in range(batch_size): 431 | latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 432 | ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 433 | for j in range(1, video_length): 434 | latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :] 435 | latents.append(latent) 436 | latents = torch.cat(latents, dim=0).to(device) 437 | else: 438 | if noise_sampling_method == "vanilla": 439 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 440 | elif noise_sampling_method == "pyoco_mixed": 441 | noise_alpha_squared = noise_alpha ** 2 442 | base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor) 443 | base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) 444 | ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 445 | latents = base_latents + ind_latents 446 | elif noise_sampling_method == "pyoco_progressive": 447 | noise_alpha_squared = noise_alpha ** 2 448 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) 449 | ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared)) 450 | for j in range(1, video_length): 451 | latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :] 452 | else: 453 | if latents.shape != shape: 
454 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 455 | latents = latents.to(device) 456 | 457 | # scale the initial noise by the standard deviation required by the scheduler 458 | latents = latents * self.scheduler.init_noise_sigma 459 | return latents 460 | 461 | @torch.no_grad() 462 | def __call__( 463 | self, 464 | prompt: Union[str, List[str]], 465 | video_length: Optional[int], 466 | height: Optional[int] = None, 467 | width: Optional[int] = None, 468 | num_inference_steps: int = 50, 469 | guidance_scale_txt: float = 7.5, 470 | guidance_scale_img: float = 2.0, 471 | negative_prompt: Optional[Union[str, List[str]]] = None, 472 | num_videos_per_prompt: Optional[int] = 1, 473 | eta: float = 0.0, 474 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 475 | latents: Optional[torch.FloatTensor] = None, 476 | output_type: Optional[str] = "tensor", 477 | return_dict: bool = True, 478 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 479 | callback_steps: Optional[int] = 1, 480 | # additional 481 | first_frame_paths: Optional[Union[str, List[str]]] = None, 482 | first_frames: Optional[torch.FloatTensor] = None, 483 | noise_sampling_method: str = "vanilla", 484 | noise_alpha: float = 1.0, 485 | guidance_rescale: float = 0.0, 486 | frame_stride: Optional[int] = None, 487 | use_frameinit: bool = False, 488 | frameinit_noise_level: int = 999, 489 | camera_motion: str = None, 490 | **kwargs, 491 | ): 492 | if first_frame_paths is not None and first_frames is not None: 493 | raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.") 494 | # Default height and width to unet 495 | height = height or self.unet.config.sample_size * self.vae_scale_factor 496 | width = width or self.unet.config.sample_size * self.vae_scale_factor 497 | 498 | # Check inputs. Raise error if not correct 499 | self.check_inputs(prompt, height, width, callback_steps, first_frame_paths) 500 | 501 | # Define call parameters 502 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 503 | batch_size = 1 504 | if latents is not None: 505 | batch_size = latents.shape[0] 506 | if isinstance(prompt, list): 507 | batch_size = len(prompt) 508 | first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames 509 | if first_frame_input is not None: 510 | assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length" 511 | 512 | device = self._execution_device 513 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 514 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 515 | # corresponds to doing no classifier free guidance. 
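        # Concretely, for the "both" mode the three noise predictions produced in the
        # denoising loop below (noise_pred_uncond, noise_pred_img, noise_pred_both) are
        # combined as
        #     noise_pred = noise_pred_uncond
        #                  + guidance_scale_img * (noise_pred_img  - noise_pred_uncond)
        #                  + guidance_scale_txt * (noise_pred_both - noise_pred_img)
        # so setting both scales to 1.0 reduces to the fully conditioned prediction
        # noise_pred_both (no extra guidance), while the "text" mode uses the usual
        # two-term formulation with guidance_scale_txt only.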
516 | do_classifier_free_guidance = None 517 | # two guidance mode: text and text+image 518 | if guidance_scale_txt > 1.0: 519 | do_classifier_free_guidance = "text" 520 | if guidance_scale_img > 1.0: 521 | do_classifier_free_guidance = "both" 522 | 523 | # Encode input prompt 524 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 525 | if negative_prompt is not None: 526 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 527 | text_embeddings = self._encode_prompt( 528 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 529 | ) 530 | 531 | # Encode input first frame 532 | first_frame_latents = None 533 | if first_frame_paths is not None: 534 | first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size 535 | if camera_motion is None: 536 | img_transform = T.Compose([ 537 | T.ToTensor(), 538 | T.Resize(height, antialias=None), 539 | T.CenterCrop((height, width)), 540 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 541 | ]) 542 | elif camera_motion == "pan_left" or camera_motion == "pan_right": 543 | img_transform = T.Compose([ 544 | T.ToTensor(), 545 | T.Resize(height, antialias=None), 546 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 547 | ]) 548 | elif camera_motion == "zoom_out" or camera_motion == "zoom_in": 549 | img_transform = T.Compose([ 550 | T.ToTensor(), 551 | T.Resize(height * 2, antialias=None), 552 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 553 | ]) 554 | 555 | first_frames = [] 556 | for first_frame_path in first_frame_paths: 557 | first_frame = Image.open(first_frame_path).convert('RGB') 558 | first_frame = img_transform(first_frame) 559 | if camera_motion is not None: 560 | if camera_motion == "pan_left": 561 | first_frame = pan_left(first_frame, num_frames=video_length, crop_width=width) 562 | elif camera_motion == "pan_right": 563 | first_frame = pan_right(first_frame, num_frames=video_length, crop_width=width) 564 | elif camera_motion == "zoom_in": 565 | first_frame = zoom_in(first_frame, num_frames=video_length, crop_width=width) 566 | elif camera_motion == "zoom_out": 567 | first_frame = zoom_out(first_frame, num_frames=video_length, crop_width=width) 568 | else: 569 | raise NotImplementedError(f"camera_motion: {camera_motion} is not implemented.") 570 | first_frames.append(first_frame.unsqueeze(0)) 571 | first_frames = torch.cat(first_frames, dim=0) 572 | if first_frames is not None: 573 | first_frames = first_frames.to(device, dtype=self.vae.dtype) 574 | if camera_motion is not None: 575 | first_frames = rearrange(first_frames, "b f c h w -> (b f) c h w") 576 | first_frame_latents = self.vae.encode(first_frames).latent_dist 577 | first_frame_latents = first_frame_latents.sample() 578 | first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w 579 | first_frame_static_vid = rearrange(first_frame_latents, "(b f) c h w -> b c f h w", f=video_length if camera_motion is not None else 1) 580 | first_frame_latents = first_frame_static_vid[:, :, 0, :, :] 581 | first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt) 582 | first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt) 583 | 584 | if use_frameinit and camera_motion is None: 585 | first_frame_static_vid = repeat(first_frame_static_vid, "b c 1 h w -> b c t h w", 
t=video_length) 586 | 587 | # self._progress_bar_config = {} 588 | # vid = self.decode_latents(first_frame_static_vid) 589 | # vid = torch.from_numpy(vid) 590 | # from ..utils.util import save_videos_grid 591 | # save_videos_grid(vid, "samples/debug/camera_motion/first_frame_static_vid.mp4", fps=8) 592 | 593 | # Prepare timesteps 594 | self.scheduler.set_timesteps(num_inference_steps, device=device) 595 | timesteps = self.scheduler.timesteps 596 | 597 | # Prepare latent variables 598 | num_channels_latents = self.unet.config.in_channels 599 | latents = self.prepare_latents( 600 | batch_size * num_videos_per_prompt, 601 | num_channels_latents, 602 | video_length, 603 | height, 604 | width, 605 | text_embeddings.dtype, 606 | device, 607 | generator, 608 | latents, 609 | noise_sampling_method, 610 | noise_alpha, 611 | ) 612 | latents_dtype = latents.dtype 613 | 614 | if use_frameinit: 615 | current_diffuse_timestep = frameinit_noise_level # diffuse to t noise level 616 | diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep)) 617 | diffuse_timesteps = diffuse_timesteps.long() 618 | z_T = self.scheduler.add_noise( 619 | original_samples=first_frame_static_vid.to(device), 620 | noise=latents.to(device), 621 | timesteps=diffuse_timesteps.to(device) 622 | ) 623 | latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents.to(dtype=torch.float32), LPF=self.freq_filter) 624 | latents = latents.to(dtype=latents_dtype) 625 | 626 | if first_frame_latents is not None: 627 | first_frame_noisy_latent = latents[:, :, 0, :, :] 628 | latents = latents[:, :, 1:, :, :] 629 | 630 | # Prepare extra step kwargs. 631 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 632 | 633 | # Denoising loop 634 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 635 | with self.progress_bar(total=num_inference_steps) as progress_bar: 636 | for i, t in enumerate(timesteps): 637 | # expand the latents if we are doing classifier free guidance 638 | if do_classifier_free_guidance is None: 639 | latent_model_input = latents 640 | elif do_classifier_free_guidance == "text": 641 | latent_model_input = torch.cat([latents] * 2) 642 | elif do_classifier_free_guidance == "both": 643 | latent_model_input = torch.cat([latents] * 3) 644 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 645 | if first_frame_latents is not None: 646 | if do_classifier_free_guidance is None: 647 | first_frame_latents_input = first_frame_latents 648 | elif do_classifier_free_guidance == "text": 649 | first_frame_latents_input = torch.cat([first_frame_latents] * 2) 650 | elif do_classifier_free_guidance == "both": 651 | first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents]) 652 | 653 | first_frame_latents_input = first_frame_latents_input.unsqueeze(2) 654 | 655 | # predict the noise residual 656 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype) 657 | else: 658 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) 659 | 660 | # perform guidance 661 | if do_classifier_free_guidance: 662 | if do_classifier_free_guidance == "text": 663 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 664 | noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond) 665 | elif 
do_classifier_free_guidance == "both": 666 | noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3) 667 | noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img) 668 | 669 | if do_classifier_free_guidance and guidance_rescale > 0.0: 670 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 671 | # currently only support text guidance 672 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 673 | 674 | # compute the previous noisy sample x_t -> x_t-1 675 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 676 | 677 | # call the callback, if provided 678 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 679 | progress_bar.update() 680 | if callback is not None and i % callback_steps == 0: 681 | callback(i, t, latents) 682 | 683 | # Post-processing 684 | latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2) 685 | # video = self.decode_latents(latents, first_frames) 686 | video = self.decode_latents(latents) 687 | 688 | # Convert to tensor 689 | if output_type == "tensor": 690 | video = torch.from_numpy(video) 691 | 692 | if not return_dict: 693 | return video 694 | 695 | return AnimationPipelineOutput(videos=video) 696 | --------------------------------------------------------------------------------
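For reference, a minimal usage sketch of the pipeline defined in this file. The checkpoint identifier below is a placeholder (it assumes the ConsistI2V weights are available as a diffusers-style pipeline folder), and the call arguments simply mirror the `__call__` signature above; the shipped inference configs may use different values.

import torch

from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
from consisti2v.utils.util import save_videos_grid  # same helper as the commented-out debug call above

# Placeholder checkpoint location; assumes a diffusers-style folder with
# vae, text_encoder, tokenizer, unet and scheduler subfolders.
pipeline = ConditionalAnimationPipeline.from_pretrained(
    "path/to/consisti2v-checkpoint", torch_dtype=torch.float16
)
pipeline.to("cuda")

output = pipeline(
    prompt=["clown fish swimming through the coral reef."],
    first_frame_paths=["assets/example/example_03.png"],
    video_length=16,
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale_txt=7.5,  # text guidance weight
    guidance_scale_img=2.0,  # first-frame guidance weight (enables the "both" mode)
)

# With the default output_type="tensor", videos has shape
# (batch, channels, frames, height, width) with values in [0, 1].
videos = output.videos
save_videos_grid(videos, "samples/example_03.mp4", fps=8)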