├── .gitignore ├── README.md ├── app.py ├── configs ├── black-swan.yaml ├── brown-bear.yaml ├── car-moving.yaml ├── car-turn.yaml ├── child-riding.yaml ├── cow-walking.yaml ├── dog-walking.yaml ├── horse-running.yaml ├── lion-roaring.yaml ├── man-running.yaml ├── man-surfing.yaml ├── rabbit-watermelon.yaml ├── skateboard-dog.yaml └── skateboard-man.yaml ├── data ├── black-swan.mp4 ├── brown-bear.mp4 ├── car-moving.mp4 ├── car-turn.mp4 ├── child-riding.mp4 ├── cow-walking.mp4 ├── dog-walking.mp4 ├── horse-running.mp4 ├── lion-roaring.mp4 ├── man-running.mp4 ├── man-surfing.mp4 ├── rabbit-watermelon.mp4 ├── skateboard-dog.avi └── skateboard-man.mp4 ├── docs └── vid2vid-zero.png ├── examples ├── child-riding_flooded.gif ├── child-riding_lego.gif ├── jeep-moving_Porsche.gif ├── jeep-moving_snow.gif ├── man-running_newyork.gif ├── man-running_stephen.gif ├── red-moving_desert.gif └── red-moving_snow.gif ├── gradio_demo ├── app_running.py ├── runner.py └── style.css ├── requirements.txt ├── test_vid2vid_zero.py └── vid2vid_zero ├── data └── dataset.py ├── models ├── attention_2d.py ├── resnet_2d.py ├── unet_2d_blocks.py └── unet_2d_condition.py ├── p2p ├── null_text_w_ptp.py ├── p2p_stable.py ├── ptp_utils.py └── seq_aligner.py ├── pipelines └── pipeline_vid2vid_zero.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom dirs 2 | checkpoints/ 3 | outputs/ 4 | 5 | # Initially taken from Github's Python gitignore files 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # tests and logs 16 | tests/fixtures/cached_*_text.txt 17 | logs/ 18 | lightning_logs/ 19 | lang_code_data/ 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | # vscode 129 | .vs 130 | .vscode 131 | 132 | # Pycharm 133 | .idea 134 | 135 | # TF code 136 | tensorflow_code 137 | 138 | # Models 139 | proc_data 140 | 141 | # examples 142 | runs 143 | /runs_old 144 | /wandb 145 | /examples/runs 146 | /examples/**/*.args 147 | /examples/rag/sweep 148 | 149 | # data 150 | /data 151 | serialization_dir 152 | 153 | # emacs 154 | *.*~ 155 | debug.env 156 | 157 | # vim 158 | .*.swp 159 | 160 | #ctags 161 | tags 162 | 163 | # pre-commit 164 | .pre-commit* 165 | 166 | # .lock 167 | *.lock 168 | 169 | # DS_Store (MacOS) 170 | .DS_Store 171 | # RL pipelines may produce mp4 outputs 172 | *.mp4 173 | 174 | # dependencies 175 | /transformers 176 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 |

vid2vid-zero for Zero-Shot Video Editing

7 | 8 |

Zero-Shot Video Editing Using Off-The-Shelf Image Diffusion Models

9 | 10 | [Wen Wang](https://scholar.google.com/citations?user=1ks0R04AAAAJ&hl=zh-CN)1*,   [Kangyang Xie](https://github.com/felix-ky)1*,   [Zide Liu](https://github.com/zideliu)1*,   [Hao Chen](https://scholar.google.com.au/citations?user=FaOqRpcAAAAJ&hl=en)1,   [Yue Cao](http://yue-cao.me/)2,   [Xinlong Wang](https://www.xloong.wang/)2,   [Chunhua Shen](https://cshen.github.io/)1 11 | 12 | 1[ZJU](https://www.zju.edu.cn/english/),   2[BAAI](https://www.baai.ac.cn/english.html) 13 | 14 |
15 | 16 | [![Hugging Face Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/BAAI/vid2vid-zero) 17 | 18 | 19 | 20 |
21 | 22 |
23 | 24 | We propose vid2vid-zero, a simple yet effective method for zero-shot video editing. Our vid2vid-zero leverages off-the-shelf image diffusion models, and doesn't require training on any video. At the core of our method is a null-text inversion module for text-to-video alignment, a cross-frame modeling module for temporal consistency, and a spatial regularization module for fidelity to the original video. Without any training, we leverage the dynamic nature of the attention mechanism to enable bi-directional temporal modeling at test time. 25 | Experiments and analyses show promising results in editing attributes, subjects, places, etc., in real-world videos. 26 | 27 | 28 | ## Highlights 29 | 30 | - Video editing with off-the-shelf image diffusion models. 31 | 32 | - No training on any video. 33 | 34 | - Promising results in editing attributes, subjects, places, etc., in real-world videos. 35 | 36 | ## News 37 | * [2023.4.12] Online Gradio Demo is available [here](https://huggingface.co/spaces/BAAI/vid2vid-zero). 38 | * [2023.4.11] Add Gradio Demo (runs in local). 39 | * [2023.4.9] Code released! 40 | 41 | ## Installation 42 | ### Requirements 43 | 44 | ```shell 45 | pip install -r requirements.txt 46 | ``` 47 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for improved efficiency and speed on GPUs. 48 | 49 | ### Weights 50 | 51 | **[Stable Diffusion]** [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. The pre-trained Stable Diffusion models can be downloaded from [🤗 Hugging Face](https://huggingface.co) (e.g., [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), [v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)). We use Stable Diffusion v1-4 by default. 52 | 53 | ## Zero-shot testing 54 | 55 | Simply run: 56 | 57 | ```bash 58 | accelerate launch test_vid2vid_zero.py --config path/to/config 59 | ``` 60 | 61 | For example: 62 | ```bash 63 | accelerate launch test_vid2vid_zero.py --config configs/car-moving.yaml 64 | ``` 65 | 66 | ## Gradio Demo 67 | Launch the local demo built with [gradio](https://gradio.app/): 68 | ```bash 69 | python app.py 70 | ``` 71 | 72 | Or you can use our online gradio demo [here](https://huggingface.co/spaces/BAAI/vid2vid-zero). 73 | 74 | Note that we disable Null-text Inversion and enable fp16 for faster demo response. 75 | 76 | ## Examples 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 |
Input VideoOutput VideoInput VideoOutput Video
"A car is moving on the road""A Porsche car is moving on the desert""A car is moving on the road""A jeep car is moving on the snow"
"A man is running""Stephen Curry is running in Time Square""A man is running""A man is running in New York City"
"A child is riding a bike on the road""a child is riding a bike on the flooded road""A child is riding a bike on the road""a lego child is riding a bike on the road.gif"
"A car is moving on the road""A car is moving on the snow""A car is moving on the road""A jeep car is moving on the desert"
134 | 135 | ## Citation 136 | 137 | ``` 138 | @article{vid2vid-zero, 139 | title={Zero-Shot Video Editing Using Off-The-Shelf Image Diffusion Models}, 140 | author={Wang, Wen and Xie, kangyang and Liu, Zide and Chen, Hao and Cao, Yue and Wang, Xinlong and Shen, Chunhua}, 141 | journal={arXiv preprint arXiv:2303.17599}, 142 | year={2023} 143 | } 144 | ``` 145 | 146 | ## Acknowledgement 147 | [Tune-A-Video](https://github.com/showlab/Tune-A-Video), [diffusers](https://github.com/huggingface/diffusers), [prompt-to-prompt](https://github.com/google/prompt-to-prompt). 148 | 149 | ## Contact 150 | 151 | **We are hiring** at all levels at BAAI Vision Team, including full-time researchers, engineers and interns. 152 | If you are interested in working with us on **foundation model, visual perception and multimodal learning**, please contact [Xinlong Wang](https://www.xloong.wang/) (`wangxinlong@baai.ac.cn`) and [Yue Cao](http://yue-cao.me/) (`caoyue@baai.ac.cn`). 153 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Most code is from https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI 2 | 3 | #!/usr/bin/env python 4 | 5 | from __future__ import annotations 6 | 7 | import os 8 | from subprocess import getoutput 9 | 10 | import gradio as gr 11 | import torch 12 | 13 | from gradio_demo.app_running import create_demo 14 | from gradio_demo.runner import Runner 15 | 16 | TITLE = '# [vid2vid-zero](https://github.com/baaivision/vid2vid-zero)' 17 | 18 | ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero' 19 | SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID) 20 | GPU_DATA = getoutput('nvidia-smi') 21 | 22 | if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID: 23 | SETTINGS = f'Settings' 24 | else: 25 | SETTINGS = 'Settings' 26 | 27 | CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU. 28 |
29 | You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces. 30 | You can use "T4 small/medium" to run this demo. 31 |
32 | ''' 33 | 34 | HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run. 35 |
36 | You can check and create your Hugging Face tokens here. 37 | You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab. 38 |
39 | ''' 40 | 41 | HF_TOKEN = os.getenv('HF_TOKEN') 42 | 43 | 44 | def show_warning(warning_text: str) -> gr.Blocks: 45 | with gr.Blocks() as demo: 46 | with gr.Box(): 47 | gr.Markdown(warning_text) 48 | return demo 49 | 50 | 51 | pipe = None 52 | runner = Runner(HF_TOKEN) 53 | 54 | with gr.Blocks(css='gradio_demo/style.css') as demo: 55 | if not torch.cuda.is_available(): 56 | show_warning(CUDA_NOT_AVAILABLE_WARNING) 57 | 58 | gr.Markdown(TITLE) 59 | with gr.Tabs(): 60 | with gr.TabItem('Zero-shot Testing'): 61 | create_demo(runner, pipe) 62 | 63 | if not HF_TOKEN: 64 | show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING) 65 | 66 | demo.queue(max_size=1).launch(share=False) 67 | -------------------------------------------------------------------------------- /configs/black-swan.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/black-swan 3 | input_data: 4 | video_path: data/black-swan.mp4 5 | prompt: a blackswan is swimming on the water 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 4 11 | validation_data: 12 | prompts: 13 | - a black swan is swimming on the water, Van Gogh style 14 | - a white swan is swimming on the water 15 | video_length: 8 16 | width: 512 17 | height: 512 18 | num_inference_steps: 50 19 | guidance_scale: 7.5 20 | num_inv_steps: 50 21 | # args for null-text inv 22 | use_null_inv: True 23 | null_inner_steps: 1 24 | null_base_lr: 1e-2 25 | null_uncond_ratio: -0.5 26 | null_normal_infer: True 27 | 28 | input_batch_size: 1 29 | seed: 33 30 | mixed_precision: "no" 31 | gradient_checkpointing: True 32 | enable_xformers_memory_efficient_attention: True 33 | # test-time adaptation 34 | use_sc_attn: True 35 | use_st_attn: True 36 | st_attn_idx: 0 37 | -------------------------------------------------------------------------------- /configs/brown-bear.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/brown-bear 3 | input_data: 4 | video_path: data/brown-bear.mp4 5 | prompt: a brown bear is sitting on the ground 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 1 11 | validation_data: 12 | prompts: 13 | - a brown bear is sitting on the grass 14 | - a black bear is sitting on the grass 15 | - a polar bear is sitting on the ground 16 | video_length: 8 17 | width: 512 18 | height: 512 19 | num_inference_steps: 50 20 | guidance_scale: 7.5 21 | num_inv_steps: 50 22 | # args for null-text inv 23 | use_null_inv: True 24 | null_inner_steps: 1 25 | null_base_lr: 1e-2 26 | null_uncond_ratio: -0.5 27 | null_normal_infer: True 28 | 29 | input_batch_size: 1 30 | seed: 33 31 | mixed_precision: "no" 32 | gradient_checkpointing: True 33 | enable_xformers_memory_efficient_attention: True 34 | # test-time adaptation 35 | use_sc_attn: True 36 | use_st_attn: True 37 | st_attn_idx: 0 38 | -------------------------------------------------------------------------------- /configs/car-moving.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/car-moving 3 | input_data: 4 | video_path: data/car-moving.mp4 5 | prompt: a car is moving on the road 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 1 11 | validation_data: 12 | prompts: 13 | - a car is moving on the snow 14 | - a jeep car is moving on the road 15 | - a jeep car is moving on the desert 16 | video_length: 8 17 | width: 512 18 | height: 512 19 | num_inference_steps: 50 20 | guidance_scale: 7.5 21 | num_inv_steps: 50 22 | # args for null-text inv 23 | use_null_inv: True 24 | null_inner_steps: 1 25 | null_base_lr: 1e-2 26 | null_uncond_ratio: -0.5 27 | null_normal_infer: True 28 | 29 | input_batch_size: 1 30 | seed: 33 31 | mixed_precision: "no" 32 | gradient_checkpointing: True 33 | enable_xformers_memory_efficient_attention: True 34 | # test-time adaptation 35 | use_sc_attn: True 36 | use_st_attn: True 37 | st_attn_idx: 0 38 | -------------------------------------------------------------------------------- /configs/car-turn.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: "outputs/car-turn" 3 | 4 | input_data: 5 | video_path: "data/car-turn.mp4" 6 | prompt: "a jeep car is moving on the road" 7 | n_sample_frames: 8 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 6 12 | 13 | validation_data: 14 | prompts: 15 | - "a jeep car is moving on the beach" 16 | - "a jeep car is moving on the snow" 17 | - "a Porsche car is moving on the desert" 18 | video_length: 8 19 | width: 512 20 | height: 512 21 | num_inference_steps: 50 22 | guidance_scale: 7.5 23 | num_inv_steps: 50 24 | # args for null-text inv 25 | use_null_inv: True 26 | null_inner_steps: 1 27 | null_base_lr: 1e-2 28 | null_uncond_ratio: -0.5 29 | null_normal_infer: True 30 | 31 | input_batch_size: 1 32 | seed: 33 33 | mixed_precision: "no" 34 | gradient_checkpointing: True 35 | enable_xformers_memory_efficient_attention: True 36 | # test-time adaptation 37 | use_sc_attn: True 38 | use_st_attn: True 39 | st_attn_idx: 0 40 | -------------------------------------------------------------------------------- /configs/child-riding.yaml: -------------------------------------------------------------------------------- 1 | 2 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 3 | output_dir: outputs/child-riding 4 | 5 | input_data: 6 | video_path: data/child-riding.mp4 7 | prompt: "a child is riding a bike on the road" 8 | n_sample_frames: 8 9 | width: 512 10 | height: 512 11 | sample_start_idx: 0 12 | sample_frame_rate: 1 13 | 14 | validation_data: 15 | # inv_latent: "outputs_2d/car-turn/inv_latents/ddim_latent-0.pt" # latent inversed w/o SCAttn ! 16 | prompts: 17 | - a lego child is riding a bike on the road 18 | - a child is riding a bike on the flooded road 19 | video_length: 8 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | num_inv_steps: 50 25 | # args for null-text inv 26 | use_null_inv: True 27 | null_inner_steps: 1 28 | null_base_lr: 1e-2 29 | null_uncond_ratio: -0.5 30 | null_normal_infer: True 31 | 32 | input_batch_size: 1 33 | seed: 33 34 | mixed_precision: "no" 35 | gradient_checkpointing: True 36 | enable_xformers_memory_efficient_attention: True 37 | # test-time adaptation 38 | use_sc_attn: True 39 | use_st_attn: True 40 | st_attn_idx: 0 41 | -------------------------------------------------------------------------------- /configs/cow-walking.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/cow-walking 3 | input_data: 4 | video_path: data/cow-walking.mp4 5 | prompt: a cow is walking on the grass 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 2 11 | validation_data: 12 | prompts: 13 | - a lion is walking on the grass 14 | - a dog is walking on the grass 15 | - a cow is walking on the snow 16 | video_length: 8 17 | width: 512 18 | height: 512 19 | num_inference_steps: 50 20 | guidance_scale: 7.5 21 | num_inv_steps: 50 22 | # args for null-text inv 23 | use_null_inv: True 24 | null_inner_steps: 1 25 | null_base_lr: 1e-2 26 | null_uncond_ratio: -0.5 27 | null_normal_infer: True 28 | 29 | input_batch_size: 1 30 | seed: 33 31 | mixed_precision: "no" 32 | gradient_checkpointing: True 33 | enable_xformers_memory_efficient_attention: True 34 | # test-time adaptation 35 | use_sc_attn: True 36 | use_st_attn: True 37 | st_attn_idx: 0 38 | -------------------------------------------------------------------------------- /configs/dog-walking.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/dog_walking 3 | input_data: 4 | video_path: data/dog-walking.mp4 5 | prompt: a dog is walking on the ground 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 15 10 | sample_frame_rate: 3 11 | validation_data: 12 | prompts: 13 | - a dog is walking on the ground, Van Gogh style 14 | video_length: 8 15 | width: 512 16 | height: 512 17 | num_inference_steps: 50 18 | guidance_scale: 7.5 19 | num_inv_steps: 50 20 | # args for null-text inv 21 | use_null_inv: True 22 | null_inner_steps: 1 23 | null_base_lr: 1e-2 24 | null_uncond_ratio: -0.5 25 | null_normal_infer: True 26 | 27 | input_batch_size: 1 28 | seed: 33 29 | mixed_precision: "no" 30 | gradient_checkpointing: True 31 | enable_xformers_memory_efficient_attention: True 32 | # test-time adaptation 33 | use_sc_attn: True 34 | use_st_attn: True 35 | st_attn_idx: 0 36 | -------------------------------------------------------------------------------- /configs/horse-running.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/horse-running 3 | input_data: 4 | video_path: data/horse-running.mp4 5 | prompt: a horse is running on the beach 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 2 11 | validation_data: 12 | prompts: 13 | - a dog is running on the beach 14 | - a dog is running on the desert 15 | video_length: 8 16 | width: 512 17 | height: 512 18 | num_inference_steps: 50 19 | guidance_scale: 7.5 20 | num_inv_steps: 50 21 | # args for null-text inv 22 | use_null_inv: True 23 | null_inner_steps: 1 24 | null_base_lr: 1e-2 25 | null_uncond_ratio: -0.5 26 | null_normal_infer: True 27 | 28 | input_batch_size: 1 29 | seed: 33 30 | mixed_precision: "no" 31 | gradient_checkpointing: True 32 | enable_xformers_memory_efficient_attention: True 33 | # test-time adaptation 34 | use_sc_attn: True 35 | use_st_attn: True 36 | st_attn_idx: 0 37 | -------------------------------------------------------------------------------- /configs/lion-roaring.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: ./outputs/lion-roaring 3 | input_data: 4 | video_path: data/lion-roaring.mp4 5 | prompt: a lion is roaring 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 2 11 | validation_data: 12 | prompts: 13 | - a lego lion is roaring 14 | - a wolf is roaring, anime style 15 | - a lion is roaring, anime style 16 | video_length: 8 17 | width: 512 18 | height: 512 19 | num_inference_steps: 50 20 | guidance_scale: 7.5 21 | num_inv_steps: 50 22 | # args for null-text inv 23 | use_null_inv: True 24 | null_inner_steps: 1 25 | null_base_lr: 1e-2 26 | null_uncond_ratio: -0.5 27 | null_normal_infer: True 28 | 29 | input_batch_size: 1 30 | seed: 33 31 | mixed_precision: "no" 32 | gradient_checkpointing: True 33 | enable_xformers_memory_efficient_attention: True 34 | # test-time adaptation 35 | use_sc_attn: True 36 | use_st_attn: True 37 | st_attn_idx: 0 38 | -------------------------------------------------------------------------------- /configs/man-running.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/man-running 3 | input_data: 4 | video_path: data/man-running.mp4 5 | prompt: a man is running 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 25 10 | sample_frame_rate: 2 11 | validation_data: 12 | prompts: 13 | - Stephen Curry is running in Time Square 14 | - a man is running, Van Gogh style 15 | - a man is running in New York City 16 | video_length: 8 17 | width: 512 18 | height: 512 19 | num_inference_steps: 50 20 | guidance_scale: 7.5 21 | num_inv_steps: 50 22 | # args for null-text inv 23 | use_null_inv: True 24 | null_inner_steps: 1 25 | null_base_lr: 1e-2 26 | null_uncond_ratio: -0.5 27 | null_normal_infer: True 28 | 29 | input_batch_size: 1 30 | seed: 33 31 | mixed_precision: "no" 32 | gradient_checkpointing: True 33 | enable_xformers_memory_efficient_attention: True 34 | # test-time adaptation 35 | use_sc_attn: True 36 | use_st_attn: True 37 | st_attn_idx: 0 38 | -------------------------------------------------------------------------------- /configs/man-surfing.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/man-surfing 3 | input_data: 4 | video_path: data/man-surfing.mp4 5 | prompt: a man is surfing 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 3 11 | validation_data: 12 | prompts: 13 | - a boy is surfing in the desert 14 | - Iron Man is surfing is surfing 15 | video_length: 8 16 | width: 512 17 | height: 512 18 | num_inference_steps: 50 19 | guidance_scale: 7.5 20 | num_inv_steps: 50 21 | # args for null-text inv 22 | use_null_inv: True 23 | null_inner_steps: 1 24 | null_base_lr: 1e-2 25 | null_uncond_ratio: -0.5 26 | null_normal_infer: True 27 | 28 | input_batch_size: 1 29 | seed: 33 30 | mixed_precision: "no" 31 | gradient_checkpointing: True 32 | enable_xformers_memory_efficient_attention: True 33 | # test-time adaptation 34 | use_sc_attn: True 35 | use_st_attn: True 36 | st_attn_idx: 0 37 | -------------------------------------------------------------------------------- /configs/rabbit-watermelon.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: "outputs/rabbit-watermelon" 3 | 4 | input_data: 5 | video_path: "data/rabbit-watermelon.mp4" 6 | prompt: "a rabbit is eating a watermelon" 7 | n_sample_frames: 8 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 6 12 | 13 | validation_data: 14 | prompts: 15 | - "a tiger is eating a watermelon" 16 | - "a rabbit is eating an orange" 17 | - "a rabbit is eating a pizza" 18 | - "a puppy is eating an orange" 19 | video_length: 8 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | num_inv_steps: 50 25 | # args for null-text inv 26 | use_null_inv: True 27 | null_inner_steps: 1 28 | null_base_lr: 1e-2 29 | null_uncond_ratio: -0.5 30 | null_normal_infer: True 31 | 32 | input_batch_size: 1 33 | seed: 33 34 | mixed_precision: "no" 35 | gradient_checkpointing: True 36 | enable_xformers_memory_efficient_attention: True 37 | # test-time adaptation 38 | use_sc_attn: True 39 | use_st_attn: True 40 | st_attn_idx: 0 41 | -------------------------------------------------------------------------------- /configs/skateboard-dog.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/skateboard-dog 3 | input_data: 4 | video_path: data/skateboard-dog.avi 5 | prompt: A man with a dog skateboarding on the road 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 3 11 | validation_data: 12 | prompts: 13 | - A man with a dog skateboarding on the desert 14 | video_length: 8 15 | width: 512 16 | height: 512 17 | num_inference_steps: 50 18 | guidance_scale: 7.5 19 | num_inv_steps: 50 20 | # args for null-text inv 21 | use_null_inv: True 22 | null_inner_steps: 1 23 | null_base_lr: 1e-2 24 | null_uncond_ratio: -0.5 25 | null_normal_infer: True 26 | 27 | input_batch_size: 1 28 | seed: 33 29 | mixed_precision: "no" 30 | gradient_checkpointing: True 31 | enable_xformers_memory_efficient_attention: True 32 | # test-time adaptation 33 | use_sc_attn: True 34 | use_st_attn: True 35 | st_attn_idx: 0 36 | -------------------------------------------------------------------------------- /configs/skateboard-man.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: checkpoints/stable-diffusion-v1-4 2 | output_dir: outputs/skateboard-man 3 | input_data: 4 | video_path: data/skateboard-man.mp4 5 | prompt: a man is playing skateboard on the ground 6 | n_sample_frames: 8 7 | width: 512 8 | height: 512 9 | sample_start_idx: 0 10 | sample_frame_rate: 3 11 | validation_data: 12 | prompts: 13 | - a boy is playing skateboard on the ground 14 | video_length: 8 15 | width: 512 16 | height: 512 17 | num_inference_steps: 50 18 | guidance_scale: 7.5 19 | num_inv_steps: 50 20 | # args for null-text inv 21 | use_null_inv: True 22 | null_inner_steps: 1 23 | null_base_lr: 1e-2 24 | null_uncond_ratio: -0.5 25 | null_normal_infer: True 26 | 27 | input_batch_size: 1 28 | seed: 33 29 | mixed_precision: "no" 30 | gradient_checkpointing: True 31 | enable_xformers_memory_efficient_attention: True 32 | # test-time adaptation 33 | use_sc_attn: True 34 | use_st_attn: True 35 | st_attn_idx: 0 36 | -------------------------------------------------------------------------------- /data/black-swan.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/black-swan.mp4 -------------------------------------------------------------------------------- /data/brown-bear.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/brown-bear.mp4 -------------------------------------------------------------------------------- /data/car-moving.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/car-moving.mp4 -------------------------------------------------------------------------------- /data/car-turn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/car-turn.mp4 -------------------------------------------------------------------------------- /data/child-riding.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/child-riding.mp4 -------------------------------------------------------------------------------- /data/cow-walking.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/cow-walking.mp4 -------------------------------------------------------------------------------- /data/dog-walking.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/dog-walking.mp4 -------------------------------------------------------------------------------- /data/horse-running.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/horse-running.mp4 -------------------------------------------------------------------------------- /data/lion-roaring.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/lion-roaring.mp4 -------------------------------------------------------------------------------- /data/man-running.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/man-running.mp4 -------------------------------------------------------------------------------- /data/man-surfing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/man-surfing.mp4 -------------------------------------------------------------------------------- /data/rabbit-watermelon.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/rabbit-watermelon.mp4 -------------------------------------------------------------------------------- /data/skateboard-dog.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/skateboard-dog.avi -------------------------------------------------------------------------------- /data/skateboard-man.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/data/skateboard-man.mp4 -------------------------------------------------------------------------------- /docs/vid2vid-zero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/docs/vid2vid-zero.png -------------------------------------------------------------------------------- /examples/child-riding_flooded.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/child-riding_flooded.gif -------------------------------------------------------------------------------- /examples/child-riding_lego.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/child-riding_lego.gif -------------------------------------------------------------------------------- /examples/jeep-moving_Porsche.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/jeep-moving_Porsche.gif -------------------------------------------------------------------------------- /examples/jeep-moving_snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/jeep-moving_snow.gif -------------------------------------------------------------------------------- /examples/man-running_newyork.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/man-running_newyork.gif -------------------------------------------------------------------------------- /examples/man-running_stephen.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/man-running_stephen.gif -------------------------------------------------------------------------------- /examples/red-moving_desert.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/red-moving_desert.gif -------------------------------------------------------------------------------- /examples/red-moving_snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/vid2vid-zero/a3e41c0f156253afa44388396ecd7452eb414607/examples/red-moving_snow.gif -------------------------------------------------------------------------------- /gradio_demo/app_running.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | 7 | import gradio as gr 8 | 9 | from gradio_demo.runner import Runner 10 | 11 | 12 | def create_demo(runner: Runner, 13 | pipe: None = None) -> gr.Blocks: 14 | hf_token = os.getenv('HF_TOKEN') 15 | with gr.Blocks() as demo: 16 | with gr.Row(): 17 | with gr.Column(): 18 | with gr.Box(): 19 | gr.Markdown('Input Data') 20 | input_video = gr.File(label='Input video') 21 | input_prompt = gr.Textbox( 22 | label='Input prompt', 23 | max_lines=1, 24 | placeholder='A car is moving on the road.') 25 | gr.Markdown(''' 26 | - Upload a video and write a `Input Prompt` that describes the video. 27 | ''') 28 | 29 | with gr.Column(): 30 | with gr.Box(): 31 | gr.Markdown('Input Parameters') 32 | with gr.Row(): 33 | model_path = gr.Text( 34 | label='Path to off-the-shelf model', 35 | value='CompVis/stable-diffusion-v1-4', 36 | max_lines=1) 37 | resolution = gr.Dropdown(choices=['512', '768'], 38 | value='512', 39 | label='Resolution', 40 | visible=False) 41 | 42 | with gr.Accordion('Advanced settings', open=False): 43 | sample_start_idx = gr.Number( 44 | label='Start Frame Index',value=0) 45 | sample_frame_rate = gr.Number( 46 | label='Frame Rate',value=1) 47 | n_sample_frames = gr.Number( 48 | label='Number of Frames',value=8) 49 | guidance_scale = gr.Number( 50 | label='Guidance Scale', value=7.5) 51 | seed = gr.Slider(label='Seed', 52 | minimum=0, 53 | maximum=100000, 54 | step=1, 55 | randomize=True, 56 | value=33) 57 | input_token = gr.Text(label='Hugging Face Write Token', 58 | placeholder='', 59 | visible=False if hf_token else True) 60 | gr.Markdown(''' 61 | - Upload input video or choose an exmple blow 62 | - Set hyperparameters & click start 63 | - It takes a few minutes to download model first 64 | ''') 65 | 66 | with gr.Row(): 67 | with gr.Column(): 68 | validation_prompt = gr.Text( 69 | label='Validation Prompt', 70 | placeholder= 71 | 'prompt to test the model, e.g: a Lego man is surfing') 72 | 73 | remove_gpu_after_running = gr.Checkbox( 74 | label='Remove GPU after running', 75 | value=False, 76 | interactive=bool(os.getenv('SPACE_ID')), 77 | visible=False) 78 | 79 | with gr.Row(): 80 | result = gr.Video(label='Result') 81 | 82 | # examples 83 | with gr.Row(): 84 | examples = [ 85 | [ 86 | 'CompVis/stable-diffusion-v1-4', 87 | "data/car-moving.mp4", 88 | 'A car is moving on the road.', 89 | 8, 0, 1, 90 | 'A jeep car is moving on the desert.', 91 | 7.5, 512, 33, 92 | False, None, 93 | ], 94 | 95 | [ 96 | 'CompVis/stable-diffusion-v1-4', 97 | "data/black-swan.mp4", 98 | 'A blackswan is swimming on the water.', 99 | 8, 0, 4, 100 | 'A white swan is swimming on the water.', 101 | 7.5, 512, 33, 102 | False, None, 103 | ], 104 | 105 | [ 106 | 'CompVis/stable-diffusion-v1-4', 107 | "data/child-riding.mp4", 108 | 'A child is riding a bike on the road.', 109 | 8, 0, 1, 110 | 'A lego child is riding a bike on the road.', 111 | 7.5, 512, 33, 112 | False, None, 113 | ], 114 | 115 | [ 116 | 'CompVis/stable-diffusion-v1-4', 117 | "data/car-turn.mp4", 118 | 'A jeep car is moving on the road.', 119 | 8, 0, 6, 120 | 'A jeep car is moving on the snow.', 121 | 7.5, 512, 33, 122 | False, None, 123 | ], 124 | 125 | [ 126 | 'CompVis/stable-diffusion-v1-4', 127 | "data/rabbit-watermelon.mp4", 128 | 'A rabbit is eating a watermelon.', 129 | 8, 0, 6, 130 | 'A puppy is eating an orange.', 131 | 7.5, 512, 33, 132 | False, None, 133 | ], 134 | 135 | ] 136 | gr.Examples(examples=examples, 137 | fn=runner.run_vid2vid_zero, 138 | inputs=[ 139 | model_path, input_video, input_prompt, 140 | n_sample_frames, sample_start_idx, sample_frame_rate, 141 | validation_prompt, guidance_scale, resolution, seed, 142 | remove_gpu_after_running, 143 | input_token, 144 | ], 145 | outputs=result, 146 | cache_examples=os.getenv('SYSTEM') == 'spaces' 147 | ) 148 | 149 | # run 150 | run_button_vid2vid_zero = gr.Button('Start vid2vid-zero') 151 | run_button_vid2vid_zero.click( 152 | fn=runner.run_vid2vid_zero, 153 | inputs=[ 154 | model_path, input_video, input_prompt, 155 | n_sample_frames, sample_start_idx, sample_frame_rate, 156 | validation_prompt, guidance_scale, resolution, seed, 157 | remove_gpu_after_running, 158 | input_token, 159 | ], 160 | outputs=result) 161 | 162 | return demo 163 | 164 | 165 | if __name__ == '__main__': 166 | hf_token = os.getenv('HF_TOKEN') 167 | runner = Runner(hf_token) 168 | demo = create_demo(runner) 169 | demo.queue(max_size=1).launch(share=False) 170 | -------------------------------------------------------------------------------- /gradio_demo/runner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | import os 5 | import pathlib 6 | import shlex 7 | import shutil 8 | import subprocess 9 | import sys 10 | 11 | import gradio as gr 12 | import slugify 13 | import torch 14 | import huggingface_hub 15 | from huggingface_hub import HfApi 16 | from omegaconf import OmegaConf 17 | 18 | 19 | ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero' 20 | SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID) 21 | 22 | 23 | class Runner: 24 | def __init__(self, hf_token: str | None = None): 25 | self.hf_token = hf_token 26 | 27 | self.checkpoint_dir = pathlib.Path('checkpoints') 28 | self.checkpoint_dir.mkdir(exist_ok=True) 29 | 30 | def download_base_model(self, base_model_id: str, token=None) -> str: 31 | model_dir = self.checkpoint_dir / base_model_id 32 | org_name = base_model_id.split('/')[0] 33 | org_dir = self.checkpoint_dir / org_name 34 | if not model_dir.exists(): 35 | org_dir.mkdir(exist_ok=True) 36 | print(f'https://huggingface.co/{base_model_id}') 37 | if token == None: 38 | subprocess.run(shlex.split(f'git lfs install'), cwd=org_dir) 39 | subprocess.run(shlex.split( 40 | f'git lfs clone https://huggingface.co/{base_model_id}'), 41 | cwd=org_dir) 42 | return model_dir.as_posix() 43 | else: 44 | temp_path = huggingface_hub.snapshot_download(base_model_id, use_auth_token=token) 45 | print(temp_path, org_dir) 46 | # subprocess.run(shlex.split(f'mv {temp_path} {model_dir.as_posix()}')) 47 | # return model_dir.as_posix() 48 | return temp_path 49 | 50 | def join_model_library_org(self, token: str) -> None: 51 | subprocess.run( 52 | shlex.split( 53 | f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}' 54 | )) 55 | 56 | def run_vid2vid_zero( 57 | self, 58 | model_path: str, 59 | input_video: str, 60 | prompt: str, 61 | n_sample_frames: int, 62 | sample_start_idx: int, 63 | sample_frame_rate: int, 64 | validation_prompt: str, 65 | guidance_scale: float, 66 | resolution: str, 67 | seed: int, 68 | remove_gpu_after_running: bool, 69 | input_token: str = None, 70 | ) -> str: 71 | 72 | if not torch.cuda.is_available(): 73 | raise gr.Error('CUDA is not available.') 74 | if input_video is None: 75 | raise gr.Error('You need to upload a video.') 76 | if not prompt: 77 | raise gr.Error('The input prompt is missing.') 78 | if not validation_prompt: 79 | raise gr.Error('The validation prompt is missing.') 80 | 81 | resolution = int(resolution) 82 | n_sample_frames = int(n_sample_frames) 83 | sample_start_idx = int(sample_start_idx) 84 | sample_frame_rate = int(sample_frame_rate) 85 | 86 | repo_dir = pathlib.Path(__file__).parent 87 | prompt_path = prompt.replace(' ', '_') 88 | output_dir = repo_dir / 'outputs' / prompt_path 89 | output_dir.mkdir(parents=True, exist_ok=True) 90 | 91 | config = OmegaConf.load('configs/black-swan.yaml') 92 | config.pretrained_model_path = self.download_base_model(model_path, token=input_token) 93 | 94 | # we remove null-inversion & use fp16 for fast inference on web demo 95 | config.mixed_precision = "fp16" 96 | config.validation_data.use_null_inv = False 97 | 98 | config.output_dir = output_dir.as_posix() 99 | config.input_data.video_path = input_video.name # type: ignore 100 | config.input_data.prompt = prompt 101 | config.input_data.n_sample_frames = n_sample_frames 102 | config.input_data.width = resolution 103 | config.input_data.height = resolution 104 | config.input_data.sample_start_idx = sample_start_idx 105 | config.input_data.sample_frame_rate = sample_frame_rate 106 | 107 | config.validation_data.prompts = [validation_prompt] 108 | config.validation_data.video_length = 8 109 | config.validation_data.width = resolution 110 | config.validation_data.height = resolution 111 | config.validation_data.num_inference_steps = 50 112 | config.validation_data.guidance_scale = guidance_scale 113 | 114 | config.input_batch_size = 1 115 | config.seed = seed 116 | 117 | config_path = output_dir / 'config.yaml' 118 | with open(config_path, 'w') as f: 119 | OmegaConf.save(config, f) 120 | 121 | command = f'accelerate launch test_vid2vid_zero.py --config {config_path}' 122 | subprocess.run(shlex.split(command)) 123 | 124 | output_video_path = os.path.join(output_dir, "sample-all.mp4") 125 | print(f"video path for gradio: {output_video_path}") 126 | message = 'Running completed!' 127 | print(message) 128 | 129 | if remove_gpu_after_running: 130 | space_id = os.getenv('SPACE_ID') 131 | if space_id: 132 | api = HfApi( 133 | token=self.hf_token if self.hf_token else input_token) 134 | api.request_space_hardware(repo_id=space_id, 135 | hardware='cpu-basic') 136 | 137 | return output_video_path 138 | -------------------------------------------------------------------------------- /gradio_demo/style.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | text-align: center; 3 | } 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | diffusers[torch]==0.11.1 4 | transformers>=4.25.1 5 | bitsandbytes==0.35.4 6 | decord==0.6.0 7 | accelerate 8 | tensorboard 9 | modelcards 10 | omegaconf 11 | einops 12 | imageio 13 | ftfy -------------------------------------------------------------------------------- /test_vid2vid_zero.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import inspect 5 | import math 6 | import os 7 | import warnings 8 | from typing import Dict, Optional, Tuple 9 | from omegaconf import OmegaConf 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | 15 | import diffusers 16 | import transformers 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import set_seed 20 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler 21 | from diffusers.optimization import get_scheduler 22 | from diffusers.utils import check_min_version 23 | from diffusers.utils.import_utils import is_xformers_available 24 | from tqdm.auto import tqdm 25 | from transformers import CLIPTextModel, CLIPTokenizer 26 | 27 | from vid2vid_zero.models.unet_2d_condition import UNet2DConditionModel 28 | from vid2vid_zero.data.dataset import VideoDataset 29 | from vid2vid_zero.pipelines.pipeline_vid2vid_zero import Vid2VidZeroPipeline 30 | from vid2vid_zero.util import save_videos_grid, save_videos_as_images, ddim_inversion 31 | from einops import rearrange 32 | 33 | from vid2vid_zero.p2p.p2p_stable import AttentionReplace, AttentionRefine 34 | from vid2vid_zero.p2p.ptp_utils import register_attention_control 35 | from vid2vid_zero.p2p.null_text_w_ptp import NullInversion 36 | 37 | 38 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 39 | check_min_version("0.10.0.dev0") 40 | 41 | logger = get_logger(__name__, log_level="INFO") 42 | 43 | 44 | def prepare_control(unet, prompts, validation_data): 45 | assert len(prompts) == 2 46 | 47 | print(prompts[0]) 48 | print(prompts[1]) 49 | length1 = len(prompts[0].split(' ')) 50 | length2 = len(prompts[1].split(' ')) 51 | if length1 == length2: 52 | # prepare for attn guidance 53 | cross_replace_steps = 0.8 54 | self_replace_steps = 0.4 55 | controller = AttentionReplace(prompts, validation_data['num_inference_steps'], 56 | cross_replace_steps=cross_replace_steps, 57 | self_replace_steps=self_replace_steps) 58 | else: 59 | cross_replace_steps = 0.8 60 | self_replace_steps = 0.4 61 | controller = AttentionRefine(prompts, validation_data['num_inference_steps'], 62 | cross_replace_steps=self_replace_steps, 63 | self_replace_steps=self_replace_steps) 64 | 65 | print(controller) 66 | register_attention_control(unet, controller) 67 | 68 | # the update of unet forward function is inplace 69 | return cross_replace_steps, self_replace_steps 70 | 71 | 72 | def main( 73 | pretrained_model_path: str, 74 | output_dir: str, 75 | input_data: Dict, 76 | validation_data: Dict, 77 | input_batch_size: int = 1, 78 | gradient_accumulation_steps: int = 1, 79 | gradient_checkpointing: bool = True, 80 | mixed_precision: Optional[str] = "fp16", 81 | enable_xformers_memory_efficient_attention: bool = True, 82 | seed: Optional[int] = None, 83 | use_sc_attn: bool = True, 84 | use_st_attn: bool = True, 85 | st_attn_idx: int = 0, 86 | fps: int = 8, 87 | ): 88 | *_, config = inspect.getargvalues(inspect.currentframe()) 89 | 90 | accelerator = Accelerator( 91 | gradient_accumulation_steps=gradient_accumulation_steps, 92 | mixed_precision=mixed_precision, 93 | ) 94 | 95 | # Make one log on every process with the configuration for debugging. 96 | logging.basicConfig( 97 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 98 | datefmt="%m/%d/%Y %H:%M:%S", 99 | level=logging.INFO, 100 | ) 101 | logger.info(accelerator.state, main_process_only=False) 102 | if accelerator.is_local_main_process: 103 | transformers.utils.logging.set_verbosity_warning() 104 | diffusers.utils.logging.set_verbosity_info() 105 | else: 106 | transformers.utils.logging.set_verbosity_error() 107 | diffusers.utils.logging.set_verbosity_error() 108 | 109 | # If passed along, set the training seed now. 110 | if seed is not None: 111 | set_seed(seed) 112 | 113 | # Handle the output folder creation 114 | if accelerator.is_main_process: 115 | os.makedirs(output_dir, exist_ok=True) 116 | os.makedirs(f"{output_dir}/sample", exist_ok=True) 117 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 118 | 119 | # Load tokenizer and models. 120 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 121 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 122 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 123 | unet = UNet2DConditionModel.from_pretrained( 124 | pretrained_model_path, subfolder="unet", use_sc_attn=use_sc_attn, 125 | use_st_attn=use_st_attn, st_attn_idx=st_attn_idx) 126 | 127 | # Freeze vae, text_encoder, and unet 128 | vae.requires_grad_(False) 129 | text_encoder.requires_grad_(False) 130 | unet.requires_grad_(False) 131 | 132 | if enable_xformers_memory_efficient_attention: 133 | if is_xformers_available(): 134 | unet.enable_xformers_memory_efficient_attention() 135 | else: 136 | raise ValueError("xformers is not available. Make sure it is installed correctly") 137 | 138 | if gradient_checkpointing: 139 | unet.enable_gradient_checkpointing() 140 | 141 | # Get the training dataset 142 | input_dataset = VideoDataset(**input_data) 143 | 144 | # Preprocessing the dataset 145 | input_dataset.prompt_ids = tokenizer( 146 | input_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 147 | ).input_ids[0] 148 | 149 | # DataLoaders creation: 150 | input_dataloader = torch.utils.data.DataLoader( 151 | input_dataset, batch_size=input_batch_size 152 | ) 153 | 154 | # Get the validation pipeline 155 | validation_pipeline = Vid2VidZeroPipeline( 156 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 157 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"), 158 | safety_checker=None, feature_extractor=None, 159 | ) 160 | validation_pipeline.enable_vae_slicing() 161 | ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') 162 | ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps) 163 | 164 | # Prepare everything with our `accelerator`. 165 | unet, input_dataloader = accelerator.prepare( 166 | unet, input_dataloader, 167 | ) 168 | 169 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 170 | # as these models are only used for inference, keeping weights in full precision is not required. 171 | weight_dtype = torch.float32 172 | if accelerator.mixed_precision == "fp16": 173 | weight_dtype = torch.float16 174 | elif accelerator.mixed_precision == "bf16": 175 | weight_dtype = torch.bfloat16 176 | 177 | # Move text_encode and vae to gpu and cast to weight_dtype 178 | text_encoder.to(accelerator.device, dtype=weight_dtype) 179 | vae.to(accelerator.device, dtype=weight_dtype) 180 | 181 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 182 | num_update_steps_per_epoch = math.ceil(len(input_dataloader) / gradient_accumulation_steps) 183 | 184 | # We need to initialize the trackers we use, and also store our configuration. 185 | # The trackers initializes automatically on the main process. 186 | if accelerator.is_main_process: 187 | accelerator.init_trackers("vid2vid-zero") 188 | 189 | # Zero-shot Eval! 190 | total_batch_size = input_batch_size * accelerator.num_processes * gradient_accumulation_steps 191 | 192 | logger.info("***** Running training *****") 193 | logger.info(f" Num examples = {len(input_dataset)}") 194 | logger.info(f" Instantaneous batch size per device = {input_batch_size}") 195 | logger.info(f" Total input batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 196 | global_step = 0 197 | 198 | unet.eval() 199 | for step, batch in enumerate(input_dataloader): 200 | samples = [] 201 | pixel_values = batch["pixel_values"].to(weight_dtype) 202 | # save input video 203 | video = (pixel_values / 2 + 0.5).clamp(0, 1).detach().cpu() 204 | video = video.permute(0, 2, 1, 3, 4) # (b, f, c, h, w) 205 | samples.append(video) 206 | # start processing 207 | video_length = pixel_values.shape[1] 208 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") 209 | latents = vae.encode(pixel_values).latent_dist.sample() 210 | # take video as input 211 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) 212 | latents = latents * 0.18215 213 | 214 | generator = torch.Generator(device="cuda") 215 | generator.manual_seed(seed) 216 | 217 | # perform inversion 218 | ddim_inv_latent = None 219 | if validation_data.use_null_inv: 220 | null_inversion = NullInversion( 221 | model=validation_pipeline, guidance_scale=validation_data.guidance_scale, null_inv_with_prompt=False, 222 | null_normal_infer=validation_data.null_normal_infer, 223 | ) 224 | ddim_inv_latent, uncond_embeddings = null_inversion.invert( 225 | latents, input_dataset.prompt, verbose=True, 226 | null_inner_steps=validation_data.null_inner_steps, 227 | null_base_lr=validation_data.null_base_lr, 228 | ) 229 | ddim_inv_latent = ddim_inv_latent.to(weight_dtype) 230 | uncond_embeddings = [embed.to(weight_dtype) for embed in uncond_embeddings] 231 | else: 232 | ddim_inv_latent = ddim_inversion( 233 | validation_pipeline, ddim_inv_scheduler, video_latent=latents, 234 | num_inv_steps=validation_data.num_inv_steps, prompt="", 235 | normal_infer=True, # we don't want to use scatn or denseattn for inversion, just use sd inferenece 236 | )[-1].to(weight_dtype) 237 | uncond_embeddings = None 238 | 239 | ddim_inv_latent = ddim_inv_latent.repeat(2, 1, 1, 1, 1) 240 | 241 | for idx, prompt in enumerate(validation_data.prompts): 242 | prompts = [input_dataset.prompt, prompt] # a list of two prompts 243 | cross_replace_steps, self_replace_steps = prepare_control(unet=unet, prompts=prompts, validation_data=validation_data) 244 | 245 | sample = validation_pipeline(prompts, generator=generator, latents=ddim_inv_latent, 246 | uncond_embeddings=uncond_embeddings, 247 | **validation_data).images 248 | 249 | assert sample.shape[0] == 2 250 | sample_inv, sample_gen = sample.chunk(2) 251 | # add input for vis 252 | save_videos_grid(sample_gen, f"{output_dir}/sample/{prompts[1]}.gif", fps=fps) 253 | samples.append(sample_gen) 254 | 255 | samples = torch.concat(samples) 256 | save_path = f"{output_dir}/sample-all.gif" 257 | save_videos_grid(samples, save_path, fps=fps) 258 | save_videos_grid(samples, save_path.replace(".gif", ".mp4"), fps=fps) # .mp4 format for gradio 259 | logger.info(f"Saved samples to {save_path}") 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser() 264 | parser.add_argument("--config", type=str, default="./configs/vid2vid_zero.yaml") 265 | args = parser.parse_args() 266 | 267 | main(**OmegaConf.load(args.config)) 268 | -------------------------------------------------------------------------------- /vid2vid_zero/data/dataset.py: -------------------------------------------------------------------------------- 1 | import decord 2 | decord.bridge.set_bridge('torch') 3 | 4 | from torch.utils.data import Dataset 5 | from einops import rearrange 6 | 7 | 8 | class VideoDataset(Dataset): 9 | def __init__( 10 | self, 11 | video_path: str, 12 | prompt: str, 13 | width: int = 512, 14 | height: int = 512, 15 | n_sample_frames: int = 8, 16 | sample_start_idx: int = 0, 17 | sample_frame_rate: int = 1, 18 | ): 19 | self.video_path = video_path 20 | self.prompt = prompt 21 | self.prompt_ids = None 22 | 23 | self.width = width 24 | self.height = height 25 | self.n_sample_frames = n_sample_frames 26 | self.sample_start_idx = sample_start_idx 27 | self.sample_frame_rate = sample_frame_rate 28 | 29 | def __len__(self): 30 | return 1 31 | 32 | def __getitem__(self, index): 33 | # load and sample video frames 34 | vr = decord.VideoReader(self.video_path, width=self.width, height=self.height) 35 | sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] 36 | video = vr.get_batch(sample_index) 37 | video = rearrange(video, "f h w c -> f c h w") 38 | 39 | example = { 40 | "pixel_values": (video / 127.5 - 1.0), 41 | "prompt_ids": self.prompt_ids 42 | } 43 | 44 | return example 45 | -------------------------------------------------------------------------------- /vid2vid_zero/models/attention_2d.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | @dataclass 20 | class Transformer2DModelOutput(BaseOutput): 21 | sample: torch.FloatTensor 22 | 23 | 24 | if is_xformers_available(): 25 | import xformers 26 | import xformers.ops 27 | else: 28 | xformers = None 29 | 30 | 31 | class Transformer2DModel(ModelMixin, ConfigMixin): 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | sample_size: Optional[int] = None, 44 | num_vector_embeds: Optional[int] = None, 45 | activation_fn: str = "geglu", 46 | num_embeds_ada_norm: Optional[int] = None, 47 | use_linear_projection: bool = False, 48 | only_cross_attention: bool = False, 49 | upcast_attention: bool = False, 50 | use_sc_attn: bool = False, 51 | use_st_attn: bool = False, 52 | ): 53 | super().__init__() 54 | self.use_linear_projection = use_linear_projection 55 | self.num_attention_heads = num_attention_heads 56 | self.attention_head_dim = attention_head_dim 57 | inner_dim = num_attention_heads * attention_head_dim 58 | 59 | # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 60 | # Define whether input is continuous or discrete depending on configuration 61 | self.is_input_continuous = in_channels is not None 62 | self.is_input_vectorized = num_vector_embeds is not None 63 | 64 | if self.is_input_continuous and self.is_input_vectorized: 65 | raise ValueError( 66 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 67 | " sure that either `in_channels` or `num_vector_embeds` is None." 68 | ) 69 | elif not self.is_input_continuous and not self.is_input_vectorized: 70 | raise ValueError( 71 | f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" 72 | " sure that either `in_channels` or `num_vector_embeds` is not None." 73 | ) 74 | 75 | # 2. Define input layers 76 | if self.is_input_continuous: 77 | self.in_channels = in_channels 78 | 79 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 80 | if use_linear_projection: 81 | self.proj_in = nn.Linear(in_channels, inner_dim) 82 | else: 83 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 84 | else: 85 | raise NotImplementedError 86 | 87 | # Define transformers blocks 88 | self.transformer_blocks = nn.ModuleList( 89 | [ 90 | BasicTransformerBlock( 91 | inner_dim, 92 | num_attention_heads, 93 | attention_head_dim, 94 | dropout=dropout, 95 | cross_attention_dim=cross_attention_dim, 96 | activation_fn=activation_fn, 97 | num_embeds_ada_norm=num_embeds_ada_norm, 98 | attention_bias=attention_bias, 99 | only_cross_attention=only_cross_attention, 100 | upcast_attention=upcast_attention, 101 | use_sc_attn=use_sc_attn, 102 | use_st_attn=True if (d == 0 and use_st_attn) else False , 103 | ) 104 | for d in range(num_layers) 105 | ] 106 | ) 107 | 108 | # 4. Define output layers 109 | if use_linear_projection: 110 | self.proj_out = nn.Linear(in_channels, inner_dim) 111 | else: 112 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 113 | 114 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, normal_infer: bool = False): 115 | # Input 116 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 117 | video_length = hidden_states.shape[2] 118 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 119 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 129 | else: 130 | inner_dim = hidden_states.shape[1] 131 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 132 | hidden_states = self.proj_in(hidden_states) 133 | 134 | # Blocks 135 | for block in self.transformer_blocks: 136 | hidden_states = block( 137 | hidden_states, 138 | encoder_hidden_states=encoder_hidden_states, 139 | timestep=timestep, 140 | video_length=video_length, 141 | normal_infer=normal_infer, 142 | ) 143 | 144 | # Output 145 | if not self.use_linear_projection: 146 | hidden_states = ( 147 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 148 | ) 149 | hidden_states = self.proj_out(hidden_states) 150 | else: 151 | hidden_states = self.proj_out(hidden_states) 152 | hidden_states = ( 153 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 154 | ) 155 | 156 | output = hidden_states + residual 157 | 158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 159 | if not return_dict: 160 | return (output,) 161 | 162 | return Transformer2DModelOutput(sample=output) 163 | 164 | 165 | class BasicTransformerBlock(nn.Module): 166 | def __init__( 167 | self, 168 | dim: int, 169 | num_attention_heads: int, 170 | attention_head_dim: int, 171 | dropout=0.0, 172 | cross_attention_dim: Optional[int] = None, 173 | activation_fn: str = "geglu", 174 | num_embeds_ada_norm: Optional[int] = None, 175 | attention_bias: bool = False, 176 | only_cross_attention: bool = False, 177 | upcast_attention: bool = False, 178 | use_sc_attn: bool = False, 179 | use_st_attn: bool = False, 180 | ): 181 | super().__init__() 182 | self.only_cross_attention = only_cross_attention 183 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 184 | 185 | # Attn with temporal modeling 186 | self.use_sc_attn = use_sc_attn 187 | self.use_st_attn = use_st_attn 188 | 189 | attn_type = SparseCausalAttention if self.use_sc_attn else CrossAttention 190 | attn_type = SpatialTemporalAttention if self.use_st_attn else attn_type 191 | self.attn1 = attn_type( 192 | query_dim=dim, 193 | heads=num_attention_heads, 194 | dim_head=attention_head_dim, 195 | dropout=dropout, 196 | bias=attention_bias, 197 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 198 | upcast_attention=upcast_attention, 199 | ) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 201 | 202 | # Cross-Attn 203 | if cross_attention_dim is not None: 204 | self.attn2 = CrossAttention( 205 | query_dim=dim, 206 | cross_attention_dim=cross_attention_dim, 207 | heads=num_attention_heads, 208 | dim_head=attention_head_dim, 209 | dropout=dropout, 210 | bias=attention_bias, 211 | upcast_attention=upcast_attention, 212 | ) # is self-attn if encoder_hidden_states is none 213 | else: 214 | self.attn2 = None 215 | 216 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 217 | 218 | if cross_attention_dim is not None: 219 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 220 | else: 221 | self.norm2 = None 222 | 223 | # 3. Feed-forward 224 | self.norm3 = nn.LayerNorm(dim) 225 | 226 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 227 | if not is_xformers_available(): 228 | print("Here is how to install it") 229 | raise ModuleNotFoundError( 230 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 231 | " xformers", 232 | name="xformers", 233 | ) 234 | elif not torch.cuda.is_available(): 235 | raise ValueError( 236 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 237 | " available for GPU " 238 | ) 239 | else: 240 | try: 241 | # Make sure we can run the memory efficient attention 242 | _ = xformers.ops.memory_efficient_attention( 243 | torch.randn((1, 2, 40), device="cuda"), 244 | torch.randn((1, 2, 40), device="cuda"), 245 | torch.randn((1, 2, 40), device="cuda"), 246 | ) 247 | except Exception as e: 248 | raise e 249 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 250 | if self.attn2 is not None: 251 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 252 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 253 | 254 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, normal_infer=False): 255 | # SparseCausal-Attention 256 | norm_hidden_states = ( 257 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 258 | ) 259 | 260 | if self.only_cross_attention: 261 | hidden_states = ( 262 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 263 | ) 264 | else: 265 | if self.use_sc_attn or self.use_st_attn: 266 | hidden_states = self.attn1( 267 | norm_hidden_states, attention_mask=attention_mask, video_length=video_length, normal_infer=normal_infer, 268 | ) + hidden_states 269 | else: 270 | # shape of hidden_states: (b*f, len, dim) 271 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 272 | 273 | if self.attn2 is not None: 274 | # Cross-Attention 275 | norm_hidden_states = ( 276 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 277 | ) 278 | hidden_states = ( 279 | self.attn2( 280 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 281 | ) 282 | + hidden_states 283 | ) 284 | 285 | # Feed-forward 286 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 287 | 288 | return hidden_states 289 | 290 | 291 | class SparseCausalAttention(CrossAttention): 292 | def forward_sc_attn(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 293 | batch_size, sequence_length, _ = hidden_states.shape 294 | 295 | encoder_hidden_states = encoder_hidden_states 296 | 297 | if self.group_norm is not None: 298 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 299 | 300 | query = self.to_q(hidden_states) 301 | dim = query.shape[-1] 302 | query = self.reshape_heads_to_batch_dim(query) 303 | 304 | if self.added_kv_proj_dim is not None: 305 | raise NotImplementedError 306 | 307 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 308 | key = self.to_k(encoder_hidden_states) 309 | value = self.to_v(encoder_hidden_states) 310 | 311 | former_frame_index = torch.arange(video_length) - 1 312 | former_frame_index[0] = 0 313 | 314 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 315 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) 316 | key = rearrange(key, "b f d c -> (b f) d c") 317 | 318 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 319 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) 320 | value = rearrange(value, "b f d c -> (b f) d c") 321 | 322 | key = self.reshape_heads_to_batch_dim(key) 323 | value = self.reshape_heads_to_batch_dim(value) 324 | 325 | if attention_mask is not None: 326 | if attention_mask.shape[-1] != query.shape[1]: 327 | target_length = query.shape[1] 328 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 329 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 330 | 331 | # attention, what we cannot get enough of 332 | if self._use_memory_efficient_attention_xformers: 333 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 334 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 335 | hidden_states = hidden_states.to(query.dtype) 336 | else: 337 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 338 | hidden_states = self._attention(query, key, value, attention_mask) 339 | else: 340 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 341 | 342 | # linear proj 343 | hidden_states = self.to_out[0](hidden_states) 344 | 345 | # dropout 346 | hidden_states = self.to_out[1](hidden_states) 347 | return hidden_states 348 | 349 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, normal_infer=False): 350 | if normal_infer: 351 | return super().forward( 352 | hidden_states=hidden_states, 353 | encoder_hidden_states=encoder_hidden_states, 354 | attention_mask=attention_mask, 355 | # video_length=video_length, 356 | ) 357 | else: 358 | return self.forward_sc_attn( 359 | hidden_states=hidden_states, 360 | encoder_hidden_states=encoder_hidden_states, 361 | attention_mask=attention_mask, 362 | video_length=video_length, 363 | ) 364 | 365 | class SpatialTemporalAttention(CrossAttention): 366 | def forward_dense_attn(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 367 | batch_size, sequence_length, _ = hidden_states.shape 368 | 369 | encoder_hidden_states = encoder_hidden_states 370 | 371 | if self.group_norm is not None: 372 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 373 | 374 | query = self.to_q(hidden_states) 375 | dim = query.shape[-1] 376 | query = self.reshape_heads_to_batch_dim(query) 377 | 378 | if self.added_kv_proj_dim is not None: 379 | raise NotImplementedError 380 | 381 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 382 | key = self.to_k(encoder_hidden_states) 383 | value = self.to_v(encoder_hidden_states) 384 | 385 | key = rearrange(key, "(b f) n d -> b f n d", f=video_length) 386 | key = key.unsqueeze(1).repeat(1, video_length, 1, 1, 1) # (b f f n d) 387 | key = rearrange(key, "b f g n d -> (b f) (g n) d") 388 | 389 | value = rearrange(value, "(b f) n d -> b f n d", f=video_length) 390 | value = value.unsqueeze(1).repeat(1, video_length, 1, 1, 1) # (b f f n d) 391 | value = rearrange(value, "b f g n d -> (b f) (g n) d") 392 | 393 | key = self.reshape_heads_to_batch_dim(key) 394 | value = self.reshape_heads_to_batch_dim(value) 395 | 396 | if attention_mask is not None: 397 | if attention_mask.shape[-1] != query.shape[1]: 398 | target_length = query.shape[1] 399 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 400 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 401 | 402 | # attention, what we cannot get enough of 403 | if self._use_memory_efficient_attention_xformers: 404 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 405 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 406 | hidden_states = hidden_states.to(query.dtype) 407 | else: 408 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 409 | hidden_states = self._attention(query, key, value, attention_mask) 410 | else: 411 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 412 | 413 | # linear proj 414 | hidden_states = self.to_out[0](hidden_states) 415 | 416 | # dropout 417 | hidden_states = self.to_out[1](hidden_states) 418 | return hidden_states 419 | 420 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, normal_infer=False): 421 | if normal_infer: 422 | return super().forward( 423 | hidden_states=hidden_states, 424 | encoder_hidden_states=encoder_hidden_states, 425 | attention_mask=attention_mask, 426 | # video_length=video_length, 427 | ) 428 | else: 429 | return self.forward_dense_attn( 430 | hidden_states=hidden_states, 431 | encoder_hidden_states=encoder_hidden_states, 432 | attention_mask=attention_mask, 433 | video_length=video_length, 434 | ) 435 | -------------------------------------------------------------------------------- /vid2vid_zero/models/resnet_2d.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class Upsample2D(nn.Module): 22 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 23 | super().__init__() 24 | self.channels = channels 25 | self.out_channels = out_channels or channels 26 | self.use_conv = use_conv 27 | self.use_conv_transpose = use_conv_transpose 28 | self.name = name 29 | 30 | conv = None 31 | if use_conv_transpose: 32 | raise NotImplementedError 33 | elif use_conv: 34 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 35 | 36 | if name == "conv": 37 | self.conv = conv 38 | else: 39 | self.Conv2d_0 = conv 40 | 41 | def forward(self, hidden_states, output_size=None): 42 | assert hidden_states.shape[1] == self.channels 43 | 44 | if self.use_conv_transpose: 45 | raise NotImplementedError 46 | 47 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 48 | dtype = hidden_states.dtype 49 | if dtype == torch.bfloat16: 50 | hidden_states = hidden_states.to(torch.float32) 51 | 52 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 53 | if hidden_states.shape[0] >= 64: 54 | hidden_states = hidden_states.contiguous() 55 | 56 | # if `output_size` is passed we force the interpolation output 57 | # size and do not make use of `scale_factor=2` 58 | if output_size is None: 59 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 60 | else: 61 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 62 | 63 | # If the input is bfloat16, we cast back to bfloat16 64 | if dtype == torch.bfloat16: 65 | hidden_states = hidden_states.to(dtype) 66 | 67 | if self.use_conv: 68 | if self.name == "conv": 69 | hidden_states = self.conv(hidden_states) 70 | else: 71 | hidden_states = self.Conv2d_0(hidden_states) 72 | 73 | return hidden_states 74 | 75 | 76 | class Downsample2D(nn.Module): 77 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 78 | super().__init__() 79 | self.channels = channels 80 | self.out_channels = out_channels or channels 81 | self.use_conv = use_conv 82 | self.padding = padding 83 | stride = 2 84 | self.name = name 85 | 86 | if use_conv: 87 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 88 | else: 89 | raise NotImplementedError 90 | 91 | if name == "conv": 92 | self.Conv2d_0 = conv 93 | self.conv = conv 94 | elif name == "Conv2d_0": 95 | self.conv = conv 96 | else: 97 | self.conv = conv 98 | 99 | def forward(self, hidden_states): 100 | assert hidden_states.shape[1] == self.channels 101 | if self.use_conv and self.padding == 0: 102 | raise NotImplementedError 103 | 104 | assert hidden_states.shape[1] == self.channels 105 | hidden_states = self.conv(hidden_states) 106 | 107 | return hidden_states 108 | 109 | 110 | class ResnetBlock2D(nn.Module): 111 | def __init__( 112 | self, 113 | *, 114 | in_channels, 115 | out_channels=None, 116 | conv_shortcut=False, 117 | dropout=0.0, 118 | temb_channels=512, 119 | groups=32, 120 | groups_out=None, 121 | pre_norm=True, 122 | eps=1e-6, 123 | non_linearity="swish", 124 | time_embedding_norm="default", 125 | output_scale_factor=1.0, 126 | use_in_shortcut=None, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 142 | 143 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 144 | 145 | if temb_channels is not None: 146 | if self.time_embedding_norm == "default": 147 | time_emb_proj_out_channels = out_channels 148 | elif self.time_embedding_norm == "scale_shift": 149 | time_emb_proj_out_channels = out_channels * 2 150 | else: 151 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 152 | 153 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 154 | else: 155 | self.time_emb_proj = None 156 | 157 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 158 | self.dropout = torch.nn.Dropout(dropout) 159 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 160 | 161 | if non_linearity == "swish": 162 | self.nonlinearity = lambda x: F.silu(x) 163 | elif non_linearity == "mish": 164 | self.nonlinearity = Mish() 165 | elif non_linearity == "silu": 166 | self.nonlinearity = nn.SiLU() 167 | 168 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 169 | 170 | self.conv_shortcut = None 171 | if self.use_in_shortcut: 172 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 173 | 174 | def forward(self, input_tensor, temb): 175 | hidden_states = input_tensor 176 | 177 | hidden_states = self.norm1(hidden_states) 178 | hidden_states = self.nonlinearity(hidden_states) 179 | 180 | hidden_states = self.conv1(hidden_states) 181 | 182 | if temb is not None: 183 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 184 | 185 | if temb is not None and self.time_embedding_norm == "default": 186 | hidden_states = hidden_states + temb 187 | 188 | hidden_states = self.norm2(hidden_states) 189 | 190 | if temb is not None and self.time_embedding_norm == "scale_shift": 191 | scale, shift = torch.chunk(temb, 2, dim=1) 192 | hidden_states = hidden_states * (1 + scale) + shift 193 | 194 | hidden_states = self.nonlinearity(hidden_states) 195 | 196 | hidden_states = self.dropout(hidden_states) 197 | hidden_states = self.conv2(hidden_states) 198 | 199 | if self.conv_shortcut is not None: 200 | input_tensor = self.conv_shortcut(input_tensor) 201 | 202 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 203 | 204 | return output_tensor 205 | 206 | 207 | class Mish(torch.nn.Module): 208 | def forward(self, hidden_states): 209 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 210 | -------------------------------------------------------------------------------- /vid2vid_zero/models/unet_2d_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention_2d import Transformer2DModel 7 | from .resnet_2d import Downsample2D, ResnetBlock2D, Upsample2D 8 | 9 | 10 | def get_down_block( 11 | down_block_type, 12 | num_layers, 13 | in_channels, 14 | out_channels, 15 | temb_channels, 16 | add_downsample, 17 | resnet_eps, 18 | resnet_act_fn, 19 | attn_num_head_channels, 20 | resnet_groups=None, 21 | cross_attention_dim=None, 22 | downsample_padding=None, 23 | dual_cross_attention=False, 24 | use_linear_projection=False, 25 | only_cross_attention=False, 26 | upcast_attention=False, 27 | resnet_time_scale_shift="default", 28 | use_sc_attn=False, 29 | use_st_attn=False, 30 | ): 31 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 32 | if down_block_type == "DownBlock2D": 33 | return DownBlock2D( 34 | num_layers=num_layers, 35 | in_channels=in_channels, 36 | out_channels=out_channels, 37 | temb_channels=temb_channels, 38 | add_downsample=add_downsample, 39 | resnet_eps=resnet_eps, 40 | resnet_act_fn=resnet_act_fn, 41 | resnet_groups=resnet_groups, 42 | downsample_padding=downsample_padding, 43 | resnet_time_scale_shift=resnet_time_scale_shift, 44 | ) 45 | elif down_block_type == "CrossAttnDownBlock2D": 46 | if cross_attention_dim is None: 47 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") 48 | return CrossAttnDownBlock2D( 49 | num_layers=num_layers, 50 | in_channels=in_channels, 51 | out_channels=out_channels, 52 | temb_channels=temb_channels, 53 | add_downsample=add_downsample, 54 | resnet_eps=resnet_eps, 55 | resnet_act_fn=resnet_act_fn, 56 | resnet_groups=resnet_groups, 57 | downsample_padding=downsample_padding, 58 | cross_attention_dim=cross_attention_dim, 59 | attn_num_head_channels=attn_num_head_channels, 60 | dual_cross_attention=dual_cross_attention, 61 | use_linear_projection=use_linear_projection, 62 | only_cross_attention=only_cross_attention, 63 | upcast_attention=upcast_attention, 64 | resnet_time_scale_shift=resnet_time_scale_shift, 65 | use_sc_attn=use_sc_attn, 66 | use_st_attn=use_st_attn, 67 | ) 68 | raise ValueError(f"{down_block_type} does not exist.") 69 | 70 | 71 | def get_up_block( 72 | up_block_type, 73 | num_layers, 74 | in_channels, 75 | out_channels, 76 | prev_output_channel, 77 | temb_channels, 78 | add_upsample, 79 | resnet_eps, 80 | resnet_act_fn, 81 | attn_num_head_channels, 82 | resnet_groups=None, 83 | cross_attention_dim=None, 84 | dual_cross_attention=False, 85 | use_linear_projection=False, 86 | only_cross_attention=False, 87 | upcast_attention=False, 88 | resnet_time_scale_shift="default", 89 | use_sc_attn=False, 90 | use_st_attn=False, 91 | ): 92 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 93 | if up_block_type == "UpBlock2D": 94 | return UpBlock2D( 95 | num_layers=num_layers, 96 | in_channels=in_channels, 97 | out_channels=out_channels, 98 | prev_output_channel=prev_output_channel, 99 | temb_channels=temb_channels, 100 | add_upsample=add_upsample, 101 | resnet_eps=resnet_eps, 102 | resnet_act_fn=resnet_act_fn, 103 | resnet_groups=resnet_groups, 104 | resnet_time_scale_shift=resnet_time_scale_shift, 105 | ) 106 | elif up_block_type == "CrossAttnUpBlock2D": 107 | if cross_attention_dim is None: 108 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") 109 | return CrossAttnUpBlock2D( 110 | num_layers=num_layers, 111 | in_channels=in_channels, 112 | out_channels=out_channels, 113 | prev_output_channel=prev_output_channel, 114 | temb_channels=temb_channels, 115 | add_upsample=add_upsample, 116 | resnet_eps=resnet_eps, 117 | resnet_act_fn=resnet_act_fn, 118 | resnet_groups=resnet_groups, 119 | cross_attention_dim=cross_attention_dim, 120 | attn_num_head_channels=attn_num_head_channels, 121 | dual_cross_attention=dual_cross_attention, 122 | use_linear_projection=use_linear_projection, 123 | only_cross_attention=only_cross_attention, 124 | upcast_attention=upcast_attention, 125 | resnet_time_scale_shift=resnet_time_scale_shift, 126 | use_sc_attn=use_sc_attn, 127 | use_st_attn=use_st_attn, 128 | ) 129 | raise ValueError(f"{up_block_type} does not exist.") 130 | 131 | 132 | class UNetMidBlock2DCrossAttn(nn.Module): 133 | def __init__( 134 | self, 135 | in_channels: int, 136 | temb_channels: int, 137 | dropout: float = 0.0, 138 | num_layers: int = 1, 139 | resnet_eps: float = 1e-6, 140 | resnet_time_scale_shift: str = "default", 141 | resnet_act_fn: str = "swish", 142 | resnet_groups: int = 32, 143 | resnet_pre_norm: bool = True, 144 | attn_num_head_channels=1, 145 | output_scale_factor=1.0, 146 | cross_attention_dim=1280, 147 | dual_cross_attention=False, 148 | use_linear_projection=False, 149 | upcast_attention=False, 150 | use_sc_attn=False, 151 | use_st_attn=False, 152 | ): 153 | super().__init__() 154 | 155 | self.has_cross_attention = True 156 | self.attn_num_head_channels = attn_num_head_channels 157 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 158 | 159 | # there is always at least one resnet 160 | resnets = [ 161 | ResnetBlock2D( 162 | in_channels=in_channels, 163 | out_channels=in_channels, 164 | temb_channels=temb_channels, 165 | eps=resnet_eps, 166 | groups=resnet_groups, 167 | dropout=dropout, 168 | time_embedding_norm=resnet_time_scale_shift, 169 | non_linearity=resnet_act_fn, 170 | output_scale_factor=output_scale_factor, 171 | pre_norm=resnet_pre_norm, 172 | ) 173 | ] 174 | attentions = [] 175 | 176 | for _ in range(num_layers): 177 | if dual_cross_attention: 178 | raise NotImplementedError 179 | attentions.append( 180 | Transformer2DModel( 181 | attn_num_head_channels, 182 | in_channels // attn_num_head_channels, 183 | in_channels=in_channels, 184 | num_layers=1, 185 | cross_attention_dim=cross_attention_dim, 186 | norm_num_groups=resnet_groups, 187 | use_linear_projection=use_linear_projection, 188 | upcast_attention=upcast_attention, 189 | use_sc_attn=use_sc_attn, 190 | use_st_attn=True if (use_st_attn and _ == 0) else False, 191 | ) 192 | ) 193 | resnets.append( 194 | ResnetBlock2D( 195 | in_channels=in_channels, 196 | out_channels=in_channels, 197 | temb_channels=temb_channels, 198 | eps=resnet_eps, 199 | groups=resnet_groups, 200 | dropout=dropout, 201 | time_embedding_norm=resnet_time_scale_shift, 202 | non_linearity=resnet_act_fn, 203 | output_scale_factor=output_scale_factor, 204 | pre_norm=resnet_pre_norm, 205 | ) 206 | ) 207 | 208 | self.attentions = nn.ModuleList(attentions) 209 | self.resnets = nn.ModuleList(resnets) 210 | 211 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, normal_infer=False): 212 | hidden_states = self.resnets[0](hidden_states, temb) 213 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 214 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample 215 | hidden_states = resnet(hidden_states, temb) 216 | 217 | return hidden_states 218 | 219 | 220 | class CrossAttnDownBlock2D(nn.Module): 221 | def __init__( 222 | self, 223 | in_channels: int, 224 | out_channels: int, 225 | temb_channels: int, 226 | dropout: float = 0.0, 227 | num_layers: int = 1, 228 | resnet_eps: float = 1e-6, 229 | resnet_time_scale_shift: str = "default", 230 | resnet_act_fn: str = "swish", 231 | resnet_groups: int = 32, 232 | resnet_pre_norm: bool = True, 233 | attn_num_head_channels=1, 234 | cross_attention_dim=1280, 235 | output_scale_factor=1.0, 236 | downsample_padding=1, 237 | add_downsample=True, 238 | dual_cross_attention=False, 239 | use_linear_projection=False, 240 | only_cross_attention=False, 241 | upcast_attention=False, 242 | use_sc_attn=False, 243 | use_st_attn=False, 244 | ): 245 | super().__init__() 246 | resnets = [] 247 | attentions = [] 248 | 249 | self.has_cross_attention = True 250 | self.attn_num_head_channels = attn_num_head_channels 251 | 252 | for i in range(num_layers): 253 | in_channels = in_channels if i == 0 else out_channels 254 | resnets.append( 255 | ResnetBlock2D( 256 | in_channels=in_channels, 257 | out_channels=out_channels, 258 | temb_channels=temb_channels, 259 | eps=resnet_eps, 260 | groups=resnet_groups, 261 | dropout=dropout, 262 | time_embedding_norm=resnet_time_scale_shift, 263 | non_linearity=resnet_act_fn, 264 | output_scale_factor=output_scale_factor, 265 | pre_norm=resnet_pre_norm, 266 | ) 267 | ) 268 | if dual_cross_attention: 269 | raise NotImplementedError 270 | attentions.append( 271 | Transformer2DModel( 272 | attn_num_head_channels, 273 | out_channels // attn_num_head_channels, 274 | in_channels=out_channels, 275 | num_layers=1, 276 | cross_attention_dim=cross_attention_dim, 277 | norm_num_groups=resnet_groups, 278 | use_linear_projection=use_linear_projection, 279 | only_cross_attention=only_cross_attention, 280 | upcast_attention=upcast_attention, 281 | use_sc_attn=use_sc_attn, 282 | use_st_attn=True if (use_st_attn and i == 0) else False, 283 | ) 284 | ) 285 | self.attentions = nn.ModuleList(attentions) 286 | self.resnets = nn.ModuleList(resnets) 287 | 288 | if add_downsample: 289 | self.downsamplers = nn.ModuleList( 290 | [ 291 | Downsample2D( 292 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 293 | ) 294 | ] 295 | ) 296 | else: 297 | self.downsamplers = None 298 | 299 | self.gradient_checkpointing = False 300 | 301 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, normal_infer=False): 302 | output_states = () 303 | 304 | for resnet, attn in zip(self.resnets, self.attentions): 305 | if self.training and self.gradient_checkpointing: 306 | 307 | def create_custom_forward(module, return_dict=None, normal_infer=False): 308 | def custom_forward(*inputs): 309 | if return_dict is not None: 310 | return module(*inputs, return_dict=return_dict, normal_infer=normal_infer) 311 | else: 312 | return module(*inputs) 313 | 314 | return custom_forward 315 | 316 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 317 | hidden_states = torch.utils.checkpoint.checkpoint( 318 | create_custom_forward(attn, return_dict=False, normal_infer=normal_infer), 319 | hidden_states, 320 | encoder_hidden_states, 321 | )[0] 322 | else: 323 | hidden_states = resnet(hidden_states, temb) 324 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample 325 | 326 | output_states += (hidden_states,) 327 | 328 | if self.downsamplers is not None: 329 | for downsampler in self.downsamplers: 330 | hidden_states = downsampler(hidden_states) 331 | 332 | output_states += (hidden_states,) 333 | 334 | return hidden_states, output_states 335 | 336 | 337 | class DownBlock2D(nn.Module): 338 | def __init__( 339 | self, 340 | in_channels: int, 341 | out_channels: int, 342 | temb_channels: int, 343 | dropout: float = 0.0, 344 | num_layers: int = 1, 345 | resnet_eps: float = 1e-6, 346 | resnet_time_scale_shift: str = "default", 347 | resnet_act_fn: str = "swish", 348 | resnet_groups: int = 32, 349 | resnet_pre_norm: bool = True, 350 | output_scale_factor=1.0, 351 | add_downsample=True, 352 | downsample_padding=1, 353 | ): 354 | super().__init__() 355 | resnets = [] 356 | 357 | for i in range(num_layers): 358 | in_channels = in_channels if i == 0 else out_channels 359 | resnets.append( 360 | ResnetBlock2D( 361 | in_channels=in_channels, 362 | out_channels=out_channels, 363 | temb_channels=temb_channels, 364 | eps=resnet_eps, 365 | groups=resnet_groups, 366 | dropout=dropout, 367 | time_embedding_norm=resnet_time_scale_shift, 368 | non_linearity=resnet_act_fn, 369 | output_scale_factor=output_scale_factor, 370 | pre_norm=resnet_pre_norm, 371 | ) 372 | ) 373 | 374 | self.resnets = nn.ModuleList(resnets) 375 | 376 | if add_downsample: 377 | self.downsamplers = nn.ModuleList( 378 | [ 379 | Downsample2D( 380 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 381 | ) 382 | ] 383 | ) 384 | else: 385 | self.downsamplers = None 386 | 387 | self.gradient_checkpointing = False 388 | 389 | def forward(self, hidden_states, temb=None): 390 | output_states = () 391 | 392 | for resnet in self.resnets: 393 | if self.training and self.gradient_checkpointing: 394 | 395 | def create_custom_forward(module): 396 | def custom_forward(*inputs): 397 | return module(*inputs) 398 | 399 | return custom_forward 400 | 401 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 402 | else: 403 | hidden_states = resnet(hidden_states, temb) 404 | 405 | output_states += (hidden_states,) 406 | 407 | if self.downsamplers is not None: 408 | for downsampler in self.downsamplers: 409 | hidden_states = downsampler(hidden_states) 410 | 411 | output_states += (hidden_states,) 412 | 413 | return hidden_states, output_states 414 | 415 | 416 | class CrossAttnUpBlock2D(nn.Module): 417 | def __init__( 418 | self, 419 | in_channels: int, 420 | out_channels: int, 421 | prev_output_channel: int, 422 | temb_channels: int, 423 | dropout: float = 0.0, 424 | num_layers: int = 1, 425 | resnet_eps: float = 1e-6, 426 | resnet_time_scale_shift: str = "default", 427 | resnet_act_fn: str = "swish", 428 | resnet_groups: int = 32, 429 | resnet_pre_norm: bool = True, 430 | attn_num_head_channels=1, 431 | cross_attention_dim=1280, 432 | output_scale_factor=1.0, 433 | add_upsample=True, 434 | dual_cross_attention=False, 435 | use_linear_projection=False, 436 | only_cross_attention=False, 437 | upcast_attention=False, 438 | use_sc_attn=False, 439 | use_st_attn=False, 440 | ): 441 | super().__init__() 442 | resnets = [] 443 | attentions = [] 444 | 445 | self.has_cross_attention = True 446 | self.attn_num_head_channels = attn_num_head_channels 447 | 448 | for i in range(num_layers): 449 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 450 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 451 | 452 | resnets.append( 453 | ResnetBlock2D( 454 | in_channels=resnet_in_channels + res_skip_channels, 455 | out_channels=out_channels, 456 | temb_channels=temb_channels, 457 | eps=resnet_eps, 458 | groups=resnet_groups, 459 | dropout=dropout, 460 | time_embedding_norm=resnet_time_scale_shift, 461 | non_linearity=resnet_act_fn, 462 | output_scale_factor=output_scale_factor, 463 | pre_norm=resnet_pre_norm, 464 | ) 465 | ) 466 | if dual_cross_attention: 467 | raise NotImplementedError 468 | attentions.append( 469 | Transformer2DModel( 470 | attn_num_head_channels, 471 | out_channels // attn_num_head_channels, 472 | in_channels=out_channels, 473 | num_layers=1, 474 | cross_attention_dim=cross_attention_dim, 475 | norm_num_groups=resnet_groups, 476 | use_linear_projection=use_linear_projection, 477 | only_cross_attention=only_cross_attention, 478 | upcast_attention=upcast_attention, 479 | use_sc_attn=use_sc_attn, 480 | use_st_attn=True if (use_st_attn and i == 0) else False, 481 | ) 482 | ) 483 | 484 | self.attentions = nn.ModuleList(attentions) 485 | self.resnets = nn.ModuleList(resnets) 486 | 487 | if add_upsample: 488 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 489 | else: 490 | self.upsamplers = None 491 | 492 | self.gradient_checkpointing = False 493 | 494 | def forward( 495 | self, 496 | hidden_states, 497 | res_hidden_states_tuple, 498 | temb=None, 499 | encoder_hidden_states=None, 500 | upsample_size=None, 501 | attention_mask=None, 502 | normal_infer=False, 503 | ): 504 | for resnet, attn in zip(self.resnets, self.attentions): 505 | # pop res hidden states 506 | res_hidden_states = res_hidden_states_tuple[-1] 507 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 508 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 509 | 510 | if self.training and self.gradient_checkpointing: 511 | 512 | def create_custom_forward(module, return_dict=None, normal_infer=False): 513 | def custom_forward(*inputs): 514 | if return_dict is not None: 515 | return module(*inputs, return_dict=return_dict, normal_infer=normal_infer) 516 | else: 517 | return module(*inputs) 518 | 519 | return custom_forward 520 | 521 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 522 | hidden_states = torch.utils.checkpoint.checkpoint( 523 | create_custom_forward(attn, return_dict=False, normal_infer=normal_infer), 524 | hidden_states, 525 | encoder_hidden_states, 526 | )[0] 527 | else: 528 | hidden_states = resnet(hidden_states, temb) 529 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample 530 | 531 | if self.upsamplers is not None: 532 | for upsampler in self.upsamplers: 533 | hidden_states = upsampler(hidden_states, upsample_size) 534 | 535 | return hidden_states 536 | 537 | 538 | class UpBlock2D(nn.Module): 539 | def __init__( 540 | self, 541 | in_channels: int, 542 | prev_output_channel: int, 543 | out_channels: int, 544 | temb_channels: int, 545 | dropout: float = 0.0, 546 | num_layers: int = 1, 547 | resnet_eps: float = 1e-6, 548 | resnet_time_scale_shift: str = "default", 549 | resnet_act_fn: str = "swish", 550 | resnet_groups: int = 32, 551 | resnet_pre_norm: bool = True, 552 | output_scale_factor=1.0, 553 | add_upsample=True, 554 | ): 555 | super().__init__() 556 | resnets = [] 557 | 558 | for i in range(num_layers): 559 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 560 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 561 | 562 | resnets.append( 563 | ResnetBlock2D( 564 | in_channels=resnet_in_channels + res_skip_channels, 565 | out_channels=out_channels, 566 | temb_channels=temb_channels, 567 | eps=resnet_eps, 568 | groups=resnet_groups, 569 | dropout=dropout, 570 | time_embedding_norm=resnet_time_scale_shift, 571 | non_linearity=resnet_act_fn, 572 | output_scale_factor=output_scale_factor, 573 | pre_norm=resnet_pre_norm, 574 | ) 575 | ) 576 | 577 | self.resnets = nn.ModuleList(resnets) 578 | 579 | if add_upsample: 580 | self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) 581 | else: 582 | self.upsamplers = None 583 | 584 | self.gradient_checkpointing = False 585 | 586 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 587 | for resnet in self.resnets: 588 | # pop res hidden states 589 | res_hidden_states = res_hidden_states_tuple[-1] 590 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 591 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 592 | 593 | if self.training and self.gradient_checkpointing: 594 | 595 | def create_custom_forward(module): 596 | def custom_forward(*inputs): 597 | return module(*inputs) 598 | 599 | return custom_forward 600 | 601 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 602 | else: 603 | hidden_states = resnet(hidden_states, temb) 604 | 605 | if self.upsamplers is not None: 606 | for upsampler in self.upsamplers: 607 | hidden_states = upsampler(hidden_states, upsample_size) 608 | 609 | return hidden_states 610 | -------------------------------------------------------------------------------- /vid2vid_zero/p2p/null_text_w_ptp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Union, Tuple, List, Callable, Dict 17 | from tqdm import tqdm 18 | import torch 19 | import torch.nn.functional as nnf 20 | import numpy as np 21 | import abc 22 | from . import ptp_utils 23 | from . import seq_aligner 24 | import shutil 25 | from torch.optim.adam import Adam 26 | from PIL import Image 27 | 28 | 29 | LOW_RESOURCE = False 30 | NUM_DDIM_STEPS = 50 31 | MAX_NUM_WORDS = 77 32 | device = torch.device('cuda') 33 | from transformers import CLIPTextModel, CLIPTokenizer 34 | 35 | pretrained_model_path = "checkpoints/stable-diffusion-v1-4/" 36 | 37 | ldm_stable = None 38 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 39 | 40 | 41 | class LocalBlend: 42 | 43 | def get_mask(self, maps, alpha, use_pool): 44 | k = 1 45 | maps = (maps * alpha).sum(-1).mean(1) 46 | if use_pool: 47 | maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) 48 | mask = nnf.interpolate(maps, size=(x_t.shape[2:])) 49 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 50 | mask = mask.gt(self.th[1-int(use_pool)]) 51 | mask = mask[:1] + mask 52 | return mask 53 | 54 | def __call__(self, x_t, attention_store): 55 | self.counter += 1 56 | if self.counter > self.start_blend: 57 | 58 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 59 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps] 60 | maps = torch.cat(maps, dim=1) 61 | mask = self.get_mask(maps, self.alpha_layers, True) 62 | if self.substruct_layers is not None: 63 | maps_sub = ~self.get_mask(maps, self.substruct_layers, False) 64 | mask = mask * maps_sub 65 | mask = mask.float() 66 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 67 | return x_t 68 | 69 | def __init__(self, prompts: List[str], words: List[List[str]], substruct_words=None, start_blend=0.2, th=(.3, .3)): 70 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 71 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 72 | if type(words_) is str: 73 | words_ = [words_] 74 | for word in words_: 75 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 76 | alpha_layers[i, :, :, :, :, ind] = 1 77 | 78 | if substruct_words is not None: 79 | substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 80 | for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): 81 | if type(words_) is str: 82 | words_ = [words_] 83 | for word in words_: 84 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 85 | substruct_layers[i, :, :, :, :, ind] = 1 86 | self.substruct_layers = substruct_layers.to(device) 87 | else: 88 | self.substruct_layers = None 89 | self.alpha_layers = alpha_layers.to(device) 90 | self.start_blend = int(start_blend * NUM_DDIM_STEPS) 91 | self.counter = 0 92 | self.th=th 93 | 94 | 95 | class EmptyControl: 96 | 97 | 98 | def step_callback(self, x_t): 99 | return x_t 100 | 101 | def between_steps(self): 102 | return 103 | 104 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 105 | return attn 106 | 107 | 108 | class AttentionControl(abc.ABC): 109 | 110 | def step_callback(self, x_t): 111 | return x_t 112 | 113 | def between_steps(self): 114 | return 115 | 116 | @property 117 | def num_uncond_att_layers(self): 118 | return self.num_att_layers if LOW_RESOURCE else 0 119 | 120 | @abc.abstractmethod 121 | def forward (self, attn, is_cross: bool, place_in_unet: str): 122 | raise NotImplementedError 123 | 124 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 125 | if self.cur_att_layer >= self.num_uncond_att_layers: 126 | if LOW_RESOURCE: 127 | attn = self.forward(attn, is_cross, place_in_unet) 128 | else: 129 | h = attn.shape[0] 130 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 131 | self.cur_att_layer += 1 132 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 133 | self.cur_att_layer = 0 134 | self.cur_step += 1 135 | self.between_steps() 136 | return attn 137 | 138 | def reset(self): 139 | self.cur_step = 0 140 | self.cur_att_layer = 0 141 | 142 | def __init__(self): 143 | self.cur_step = 0 144 | self.num_att_layers = -1 145 | self.cur_att_layer = 0 146 | 147 | 148 | class SpatialReplace(EmptyControl): 149 | 150 | def step_callback(self, x_t): 151 | if self.cur_step < self.stop_inject: 152 | b = x_t.shape[0] 153 | x_t = x_t[:1].expand(b, *x_t.shape[1:]) 154 | return x_t 155 | 156 | def __init__(self, stop_inject: float): 157 | super(SpatialReplace, self).__init__() 158 | self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS) 159 | 160 | 161 | class AttentionStore(AttentionControl): 162 | 163 | @staticmethod 164 | def get_empty_store(): 165 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 166 | "down_self": [], "mid_self": [], "up_self": []} 167 | 168 | def forward(self, attn, is_cross: bool, place_in_unet: str): 169 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 170 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 171 | self.step_store[key].append(attn) 172 | return attn 173 | 174 | def between_steps(self): 175 | if len(self.attention_store) == 0: 176 | self.attention_store = self.step_store 177 | else: 178 | for key in self.attention_store: 179 | for i in range(len(self.attention_store[key])): 180 | self.attention_store[key][i] += self.step_store[key][i] 181 | self.step_store = self.get_empty_store() 182 | 183 | def get_average_attention(self): 184 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 185 | return average_attention 186 | 187 | 188 | def reset(self): 189 | super(AttentionStore, self).reset() 190 | self.step_store = self.get_empty_store() 191 | self.attention_store = {} 192 | 193 | def __init__(self): 194 | super(AttentionStore, self).__init__() 195 | self.step_store = self.get_empty_store() 196 | self.attention_store = {} 197 | 198 | 199 | class AttentionControlEdit(AttentionStore, abc.ABC): 200 | 201 | def step_callback(self, x_t): 202 | if self.local_blend is not None: 203 | x_t = self.local_blend(x_t, self.attention_store) 204 | return x_t 205 | 206 | def replace_self_attention(self, attn_base, att_replace, place_in_unet): 207 | if att_replace.shape[2] <= 32 ** 2: 208 | attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 209 | return attn_base 210 | else: 211 | return att_replace 212 | 213 | @abc.abstractmethod 214 | def replace_cross_attention(self, attn_base, att_replace): 215 | raise NotImplementedError 216 | 217 | def forward(self, attn, is_cross: bool, place_in_unet: str): 218 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 219 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 220 | h = attn.shape[0] // (self.batch_size) 221 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 222 | attn_base, attn_repalce = attn[0], attn[1:] 223 | if is_cross: 224 | alpha_words = self.cross_replace_alpha[self.cur_step] 225 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce 226 | attn[1:] = attn_repalce_new 227 | else: 228 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet) 229 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 230 | return attn 231 | 232 | def __init__(self, prompts, num_steps: int, 233 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 234 | self_replace_steps: Union[float, Tuple[float, float]], 235 | local_blend: Optional[LocalBlend]): 236 | super(AttentionControlEdit, self).__init__() 237 | self.batch_size = len(prompts) 238 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device) 239 | if type(self_replace_steps) is float: 240 | self_replace_steps = 0, self_replace_steps 241 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 242 | self.local_blend = local_blend 243 | 244 | class AttentionReplace(AttentionControlEdit): 245 | 246 | def replace_cross_attention(self, attn_base, att_replace): 247 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 248 | 249 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 250 | local_blend: Optional[LocalBlend] = None): 251 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 252 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 253 | 254 | 255 | class AttentionRefine(AttentionControlEdit): 256 | 257 | def replace_cross_attention(self, attn_base, att_replace): 258 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 259 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 260 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 261 | return attn_replace 262 | 263 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 264 | local_blend: Optional[LocalBlend] = None): 265 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 266 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 267 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) 268 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 269 | 270 | 271 | class AttentionReweight(AttentionControlEdit): 272 | 273 | def replace_cross_attention(self, attn_base, att_replace): 274 | if self.prev_controller is not None: 275 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 276 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 277 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 278 | return attn_replace 279 | 280 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 281 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): 282 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 283 | self.equalizer = equalizer.to(device) 284 | self.prev_controller = controller 285 | 286 | 287 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], 288 | Tuple[float, ...]]): 289 | if type(word_select) is int or type(word_select) is str: 290 | word_select = (word_select,) 291 | equalizer = torch.ones(1, 77) 292 | 293 | for word, val in zip(word_select, values): 294 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 295 | equalizer[:, inds] = val 296 | return equalizer 297 | 298 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): 299 | out = [] 300 | attention_maps = attention_store.get_average_attention() 301 | num_pixels = res ** 2 302 | for location in from_where: 303 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 304 | if item.shape[1] == num_pixels: 305 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 306 | out.append(cross_maps) 307 | out = torch.cat(out, dim=0) 308 | out = out.sum(0) / out.shape[0] 309 | return out.cpu() 310 | 311 | 312 | def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit: 313 | if blend_words is None: 314 | lb = None 315 | else: 316 | lb = LocalBlend(prompts, blend_word) 317 | if is_replace_controller: 318 | controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb) 319 | else: 320 | controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb) 321 | if equilizer_params is not None: 322 | eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"]) 323 | controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, 324 | self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller) 325 | return controller 326 | 327 | 328 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0): 329 | tokens = tokenizer.encode(prompts[select]) 330 | decoder = tokenizer.decode 331 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select) 332 | images = [] 333 | for i in range(len(tokens)): 334 | image = attention_maps[:, :, i] 335 | image = 255 * image / image.max() 336 | image = image.unsqueeze(-1).expand(*image.shape, 3) 337 | image = image.numpy().astype(np.uint8) 338 | image = np.array(Image.fromarray(image).resize((256, 256))) 339 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 340 | images.append(image) 341 | ptp_utils.view_images(np.stack(images, axis=0)) 342 | 343 | 344 | class NullInversion: 345 | 346 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 347 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 348 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 349 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 350 | beta_prod_t = 1 - alpha_prod_t 351 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 352 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output 353 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction 354 | return prev_sample 355 | 356 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 357 | timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep 358 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 359 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 360 | beta_prod_t = 1 - alpha_prod_t 361 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 362 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 363 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 364 | return next_sample 365 | 366 | def get_noise_pred_single(self, latents, t, context, normal_infer=True): 367 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"] 368 | return noise_pred 369 | 370 | def get_noise_pred(self, latents, t, is_forward=True, context=None, normal_infer=True): 371 | latents_input = torch.cat([latents] * 2) 372 | if context is None: 373 | context = self.context 374 | guidance_scale = 1 if is_forward else self.guidance_scale 375 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"] 376 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 377 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 378 | if is_forward: 379 | latents = self.next_step(noise_pred, t, latents) 380 | else: 381 | latents = self.prev_step(noise_pred, t, latents) 382 | return latents 383 | 384 | @torch.no_grad() 385 | def latent2image(self, latents, return_type='np'): 386 | latents = 1 / 0.18215 * latents.detach() 387 | image = self.model.vae.decode(latents)['sample'] 388 | if return_type == 'np': 389 | image = (image / 2 + 0.5).clamp(0, 1) 390 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 391 | image = (image * 255).astype(np.uint8) 392 | return image 393 | 394 | @torch.no_grad() 395 | def image2latent(self, image): 396 | with torch.no_grad(): 397 | if type(image) is Image: 398 | image = np.array(image) 399 | if type(image) is torch.Tensor and image.dim() == 4: 400 | latents = image 401 | else: 402 | image = torch.from_numpy(image).float() / 127.5 - 1 403 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 404 | latents = self.model.vae.encode(image)['latent_dist'].mean 405 | latents = latents * 0.18215 406 | return latents 407 | 408 | @torch.no_grad() 409 | def init_prompt(self, prompt: str): 410 | uncond_input = self.model.tokenizer( 411 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, 412 | return_tensors="pt" 413 | ) 414 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 415 | text_input = self.model.tokenizer( 416 | [prompt], 417 | padding="max_length", 418 | max_length=self.model.tokenizer.model_max_length, 419 | truncation=True, 420 | return_tensors="pt", 421 | ) 422 | # (1, 77, 768) 423 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 424 | # (2, 77, 768) 425 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 426 | self.prompt = prompt 427 | 428 | @torch.no_grad() 429 | def ddim_loop(self, latent): 430 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 431 | cond = cond_embeddings if self.null_inv_with_prompt else uncond_embeddings 432 | all_latent = [latent] 433 | latent = latent.clone().detach() 434 | for i in range(NUM_DDIM_STEPS): 435 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] 436 | noise_pred = self.get_noise_pred_single(latent, t, cond, normal_infer=True) 437 | latent = self.next_step(noise_pred, t, latent) 438 | all_latent.append(latent) 439 | return all_latent 440 | 441 | @property 442 | def scheduler(self): 443 | return self.model.scheduler 444 | 445 | @torch.no_grad() 446 | def ddim_inversion(self, latent): 447 | ddim_latents = self.ddim_loop(latent) 448 | return ddim_latents 449 | 450 | def null_optimization(self, latents, null_inner_steps, epsilon, null_base_lr=1e-2): 451 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 452 | uncond_embeddings_list = [] 453 | latent_cur = latents[-1] 454 | bar = tqdm(total=null_inner_steps * NUM_DDIM_STEPS) 455 | for i in range(NUM_DDIM_STEPS): 456 | uncond_embeddings = uncond_embeddings.clone().detach() 457 | uncond_embeddings.requires_grad = True 458 | optimizer = Adam([uncond_embeddings], lr=null_base_lr * (1. - i / 100.)) 459 | latent_prev = latents[len(latents) - i - 2] 460 | t = self.model.scheduler.timesteps[i] 461 | with torch.no_grad(): 462 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings, normal_infer=self.null_normal_infer) 463 | for j in range(null_inner_steps): 464 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings, normal_infer=self.null_normal_infer) 465 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 466 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) 467 | loss = nnf.mse_loss(latents_prev_rec, latent_prev) 468 | optimizer.zero_grad() 469 | loss.backward() 470 | optimizer.step() 471 | assert not torch.isnan(uncond_embeddings.abs().mean()) 472 | loss_item = loss.item() 473 | bar.update() 474 | if loss_item < epsilon + i * 2e-5: 475 | break 476 | for j in range(j + 1, null_inner_steps): 477 | bar.update() 478 | uncond_embeddings_list.append(uncond_embeddings[:1].detach()) 479 | with torch.no_grad(): 480 | context = torch.cat([uncond_embeddings, cond_embeddings]) 481 | latent_cur = self.get_noise_pred(latent_cur, t, False, context, normal_infer=self.null_normal_infer) 482 | bar.close() 483 | return uncond_embeddings_list 484 | 485 | def invert(self, latents: torch.Tensor, prompt: str, null_inner_steps=10, early_stop_epsilon=1e-5, verbose=False, null_base_lr=1e-2): 486 | self.init_prompt(prompt) 487 | if verbose: 488 | print("DDIM inversion...") 489 | ddim_latents = self.ddim_inversion(latents.to(torch.float32)) 490 | if verbose: 491 | print("Null-text optimization...") 492 | uncond_embeddings = self.null_optimization(ddim_latents, null_inner_steps, early_stop_epsilon, null_base_lr=null_base_lr) 493 | return ddim_latents[-1], uncond_embeddings 494 | 495 | 496 | def __init__(self, model, guidance_scale, null_inv_with_prompt, null_normal_infer=True): 497 | self.null_normal_infer = null_normal_infer 498 | self.null_inv_with_prompt = null_inv_with_prompt 499 | self.guidance_scale = guidance_scale 500 | self.model = model 501 | self.tokenizer = self.model.tokenizer 502 | self.model.scheduler.set_timesteps(NUM_DDIM_STEPS) 503 | self.prompt = None 504 | self.context = None 505 | -------------------------------------------------------------------------------- /vid2vid_zero/p2p/p2p_stable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Union, Tuple, List, Callable, Dict 16 | import torch 17 | import torch.nn.functional as nnf 18 | import numpy as np 19 | import abc 20 | from . import ptp_utils 21 | from . import seq_aligner 22 | from transformers import CLIPTextModel, CLIPTokenizer 23 | 24 | pretrained_model_path = "checkpoints/stable-diffusion-v1-4/" 25 | ldm_stable = None 26 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 27 | 28 | LOW_RESOURCE = False 29 | NUM_DIFFUSION_STEPS = 50 30 | GUIDANCE_SCALE = 7.5 31 | MAX_NUM_WORDS = 77 32 | device = torch.device('cuda') 33 | 34 | 35 | class LocalBlend: 36 | 37 | def __call__(self, x_t, attention_store): 38 | k = 1 39 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 40 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps] 41 | maps = torch.cat(maps, dim=1) 42 | maps = (maps * self.alpha_layers).sum(-1).mean(1) 43 | mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) 44 | mask = nnf.interpolate(mask, size=(x_t.shape[2:])) 45 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 46 | mask = mask.gt(self.threshold) 47 | mask = (mask[:1] + mask[1:]).float() 48 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 49 | return x_t 50 | 51 | # def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3): 52 | def __init__(self, prompts: List[str], words: List[List[str]], threshold=.3): 53 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 54 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 55 | if type(words_) is str: 56 | words_ = [words_] 57 | for word in words_: 58 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 59 | alpha_layers[i, :, :, :, :, ind] = 1 60 | self.alpha_layers = alpha_layers.to(device) 61 | self.threshold = threshold 62 | 63 | 64 | class AttentionControl(abc.ABC): 65 | 66 | def step_callback(self, x_t): 67 | return x_t 68 | 69 | def between_steps(self): 70 | return 71 | 72 | @property 73 | def num_uncond_att_layers(self): 74 | return self.num_att_layers if LOW_RESOURCE else 0 75 | 76 | @abc.abstractmethod 77 | def forward (self, attn, is_cross: bool, place_in_unet: str): 78 | raise NotImplementedError 79 | 80 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 81 | if self.cur_att_layer >= self.num_uncond_att_layers: 82 | if LOW_RESOURCE: 83 | attn = self.forward(attn, is_cross, place_in_unet) 84 | else: 85 | h = attn.shape[0] 86 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 87 | self.cur_att_layer += 1 88 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 89 | self.cur_att_layer = 0 90 | self.cur_step += 1 91 | self.between_steps() 92 | return attn 93 | 94 | def reset(self): 95 | self.cur_step = 0 96 | self.cur_att_layer = 0 97 | 98 | def __init__(self): 99 | self.cur_step = 0 100 | self.num_att_layers = -1 101 | self.cur_att_layer = 0 102 | 103 | class EmptyControl(AttentionControl): 104 | 105 | def forward (self, attn, is_cross: bool, place_in_unet: str): 106 | return attn 107 | 108 | 109 | class AttentionStore(AttentionControl): 110 | 111 | @staticmethod 112 | def get_empty_store(): 113 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 114 | "down_self": [], "mid_self": [], "up_self": []} 115 | 116 | def forward(self, attn, is_cross: bool, place_in_unet: str): 117 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 118 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 119 | self.step_store[key].append(attn) 120 | return attn 121 | 122 | def between_steps(self): 123 | if len(self.attention_store) == 0: 124 | self.attention_store = self.step_store 125 | else: 126 | for key in self.attention_store: 127 | for i in range(len(self.attention_store[key])): 128 | self.attention_store[key][i] += self.step_store[key][i] 129 | self.step_store = self.get_empty_store() 130 | 131 | def get_average_attention(self): 132 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 133 | return average_attention 134 | 135 | 136 | def reset(self): 137 | super(AttentionStore, self).reset() 138 | self.step_store = self.get_empty_store() 139 | self.attention_store = {} 140 | 141 | def __init__(self): 142 | super(AttentionStore, self).__init__() 143 | self.step_store = self.get_empty_store() 144 | self.attention_store = {} 145 | 146 | 147 | class AttentionControlEdit(AttentionStore, abc.ABC): 148 | 149 | def step_callback(self, x_t): 150 | if self.local_blend is not None: 151 | x_t = self.local_blend(x_t, self.attention_store) 152 | return x_t 153 | 154 | def replace_self_attention(self, attn_base, att_replace): 155 | if att_replace.shape[2] <= 16 ** 2: 156 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 157 | else: 158 | return att_replace 159 | 160 | @abc.abstractmethod 161 | def replace_cross_attention(self, attn_base, att_replace): 162 | raise NotImplementedError 163 | 164 | def forward(self, attn, is_cross: bool, place_in_unet: str): 165 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 166 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 167 | h = attn.shape[0] // (self.batch_size) 168 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 169 | attn_base, attn_repalce = attn[0], attn[1:] 170 | if is_cross: 171 | alpha_words = self.cross_replace_alpha[self.cur_step] 172 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce 173 | attn[1:] = attn_repalce_new 174 | else: 175 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 176 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 177 | return attn 178 | 179 | def __init__(self, prompts, num_steps: int, 180 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 181 | self_replace_steps: Union[float, Tuple[float, float]], 182 | local_blend: Optional[LocalBlend]): 183 | super(AttentionControlEdit, self).__init__() 184 | self.batch_size = len(prompts) 185 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device) 186 | if type(self_replace_steps) is float: 187 | self_replace_steps = 0, self_replace_steps 188 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 189 | self.local_blend = local_blend 190 | 191 | 192 | class AttentionReplace(AttentionControlEdit): 193 | 194 | def replace_cross_attention(self, attn_base, att_replace): 195 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 196 | 197 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 198 | local_blend: Optional[LocalBlend] = None): 199 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 200 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 201 | 202 | 203 | class AttentionRefine(AttentionControlEdit): 204 | 205 | def replace_cross_attention(self, attn_base, att_replace): 206 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 207 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 208 | return attn_replace 209 | 210 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 211 | local_blend: Optional[LocalBlend] = None): 212 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 213 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 214 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) 215 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 216 | 217 | 218 | class AttentionReweight(AttentionControlEdit): 219 | 220 | def replace_cross_attention(self, attn_base, att_replace): 221 | if self.prev_controller is not None: 222 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 223 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 224 | return attn_replace 225 | 226 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 227 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): 228 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 229 | self.equalizer = equalizer.to(device) 230 | self.prev_controller = controller 231 | 232 | 233 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], 234 | Tuple[float, ...]]): 235 | if type(word_select) is int or type(word_select) is str: 236 | word_select = (word_select,) 237 | equalizer = torch.ones(len(values), 77) 238 | values = torch.tensor(values, dtype=torch.float32) 239 | for word in word_select: 240 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 241 | equalizer[:, inds] = values 242 | return equalizer 243 | -------------------------------------------------------------------------------- /vid2vid_zero/p2p/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm import tqdm 22 | import torch.nn.functional as F 23 | 24 | 25 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 26 | h, w, c = image.shape 27 | offset = int(h * .2) 28 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 29 | font = cv2.FONT_HERSHEY_SIMPLEX 30 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 31 | img[:h] = image 32 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 33 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 34 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 35 | return img 36 | 37 | 38 | def view_images(images, num_rows=1, offset_ratio=0.02): 39 | if type(images) is list: 40 | num_empty = len(images) % num_rows 41 | elif images.ndim == 4: 42 | num_empty = images.shape[0] % num_rows 43 | else: 44 | images = [images] 45 | num_empty = 0 46 | 47 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 48 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 49 | num_items = len(images) 50 | 51 | h, w, c = images[0].shape 52 | offset = int(h * offset_ratio) 53 | num_cols = num_items // num_rows 54 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 55 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 56 | for i in range(num_rows): 57 | for j in range(num_cols): 58 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 59 | i * num_cols + j] 60 | 61 | pil_img = Image.fromarray(image_) 62 | display(pil_img) 63 | 64 | 65 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 66 | if low_resource: 67 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 68 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 69 | else: 70 | latents_input = torch.cat([latents] * 2) 71 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 72 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 73 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 74 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 75 | latents = controller.step_callback(latents) 76 | return latents 77 | 78 | 79 | def latent2image(vae, latents): 80 | latents = 1 / 0.18215 * latents 81 | image = vae.decode(latents)['sample'] 82 | image = (image / 2 + 0.5).clamp(0, 1) 83 | image = image.cpu().permute(0, 2, 3, 1).numpy() 84 | image = (image * 255).astype(np.uint8) 85 | return image 86 | 87 | 88 | def init_latent(latent, model, height, width, generator, batch_size): 89 | if latent is None: 90 | latent = torch.randn( 91 | (1, model.unet.in_channels, height // 8, width // 8), 92 | generator=generator, 93 | ) 94 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 95 | return latent, latents 96 | 97 | 98 | @torch.no_grad() 99 | def text2image_ldm( 100 | model, 101 | prompt: List[str], 102 | controller, 103 | num_inference_steps: int = 50, 104 | guidance_scale: Optional[float] = 7., 105 | generator: Optional[torch.Generator] = None, 106 | latent: Optional[torch.FloatTensor] = None, 107 | ): 108 | register_attention_control(model, controller) 109 | height = width = 256 110 | batch_size = len(prompt) 111 | 112 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 113 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 114 | 115 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 116 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 117 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 118 | context = torch.cat([uncond_embeddings, text_embeddings]) 119 | 120 | model.scheduler.set_timesteps(num_inference_steps) 121 | for t in tqdm(model.scheduler.timesteps): 122 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 123 | 124 | image = latent2image(model.vqvae, latents) 125 | 126 | return image, latent 127 | 128 | 129 | @torch.no_grad() 130 | def text2image_ldm_stable( 131 | model, 132 | prompt: List[str], 133 | controller, 134 | num_inference_steps: int = 50, 135 | guidance_scale: float = 7.5, 136 | generator: Optional[torch.Generator] = None, 137 | latent: Optional[torch.FloatTensor] = None, 138 | low_resource: bool = False, 139 | ): 140 | register_attention_control(model, controller) 141 | height = width = 512 142 | batch_size = len(prompt) 143 | 144 | text_input = model.tokenizer( 145 | prompt, 146 | padding="max_length", 147 | max_length=model.tokenizer.model_max_length, 148 | truncation=True, 149 | return_tensors="pt", 150 | ) 151 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 152 | max_length = text_input.input_ids.shape[-1] 153 | uncond_input = model.tokenizer( 154 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 155 | ) 156 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 157 | 158 | context = [uncond_embeddings, text_embeddings] 159 | if not low_resource: 160 | context = torch.cat(context) 161 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 162 | 163 | # set timesteps 164 | extra_set_kwargs = {"offset": 1} 165 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 166 | for t in tqdm(model.scheduler.timesteps): 167 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 168 | 169 | image = latent2image(model.vae, latents) 170 | 171 | return image, latent 172 | 173 | 174 | def register_attention_control(model, controller): 175 | 176 | def ca_forward(self, place_in_unet): 177 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): 178 | batch_size, sequence_length, _ = hidden_states.shape 179 | 180 | is_cross = encoder_hidden_states is not None 181 | encoder_hidden_states = encoder_hidden_states 182 | 183 | if self.group_norm is not None: 184 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 185 | 186 | query = self.to_q(hidden_states) 187 | # dim = query.shape[-1] 188 | query = self.reshape_heads_to_batch_dim(query) 189 | 190 | if self.added_kv_proj_dim is not None: 191 | key = self.to_k(hidden_states) 192 | value = self.to_v(hidden_states) 193 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) 194 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) 195 | 196 | key = self.reshape_heads_to_batch_dim(key) 197 | value = self.reshape_heads_to_batch_dim(value) 198 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) 199 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) 200 | 201 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) 202 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) 203 | else: 204 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 205 | key = self.to_k(encoder_hidden_states) 206 | value = self.to_v(encoder_hidden_states) 207 | 208 | key = self.reshape_heads_to_batch_dim(key) 209 | value = self.reshape_heads_to_batch_dim(value) 210 | 211 | if attention_mask is not None: 212 | if attention_mask.shape[-1] != query.shape[1]: 213 | target_length = query.shape[1] 214 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 215 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 216 | 217 | assert self._slice_size is None or query.shape[0] // self._slice_size == 1 218 | 219 | if self.upcast_attention: 220 | query = query.float() 221 | key = key.float() 222 | 223 | attention_scores = torch.baddbmm( 224 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 225 | query, 226 | key.transpose(-1, -2), 227 | beta=0, 228 | alpha=self.scale, 229 | ) 230 | 231 | if attention_mask is not None: 232 | attention_scores = attention_scores + attention_mask 233 | 234 | if self.upcast_softmax: 235 | attention_scores = attention_scores.float() 236 | 237 | attention_probs = attention_scores.softmax(dim=-1) 238 | 239 | # attn control 240 | attention_probs = controller(attention_probs, is_cross, place_in_unet) 241 | 242 | # cast back to the original dtype 243 | attention_probs = attention_probs.to(value.dtype) 244 | 245 | # compute attention output 246 | hidden_states = torch.bmm(attention_probs, value) 247 | 248 | # reshape hidden_states 249 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 250 | 251 | # linear proj 252 | hidden_states = self.to_out[0](hidden_states) 253 | 254 | # dropout 255 | hidden_states = self.to_out[1](hidden_states) 256 | return hidden_states 257 | 258 | return forward 259 | 260 | class DummyController: 261 | 262 | def __call__(self, *args): 263 | return args[0] 264 | 265 | def __init__(self): 266 | self.num_att_layers = 0 267 | 268 | if controller is None: 269 | controller = DummyController() 270 | 271 | def register_recr(net_, count, place_in_unet): 272 | if net_.__class__.__name__ == 'CrossAttention': 273 | net_.forward = ca_forward(net_, place_in_unet) 274 | return count + 1 275 | elif hasattr(net_, 'children'): 276 | for net__ in net_.children(): 277 | count = register_recr(net__, count, place_in_unet) 278 | return count 279 | 280 | cross_att_count = 0 281 | # sub_nets = model.unet.named_children() 282 | # we take unet as the input model 283 | sub_nets = model.named_children() 284 | for net in sub_nets: 285 | if "down" in net[0]: 286 | cross_att_count += register_recr(net[1], 0, "down") 287 | elif "up" in net[0]: 288 | cross_att_count += register_recr(net[1], 0, "up") 289 | elif "mid" in net[0]: 290 | cross_att_count += register_recr(net[1], 0, "mid") 291 | 292 | controller.num_att_layers = cross_att_count 293 | 294 | 295 | def get_word_inds(text: str, word_place: int, tokenizer): 296 | split_text = text.split(" ") 297 | if type(word_place) is str: 298 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 299 | elif type(word_place) is int: 300 | word_place = [word_place] 301 | out = [] 302 | if len(word_place) > 0: 303 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 304 | cur_len, ptr = 0, 0 305 | 306 | for i in range(len(words_encode)): 307 | cur_len += len(words_encode[i]) 308 | if ptr in word_place: 309 | out.append(i + 1) 310 | if cur_len >= len(split_text[ptr]): 311 | ptr += 1 312 | cur_len = 0 313 | return np.array(out) 314 | 315 | 316 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 317 | word_inds: Optional[torch.Tensor]=None): 318 | if type(bounds) is float: 319 | bounds = 0, bounds 320 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 321 | if word_inds is None: 322 | word_inds = torch.arange(alpha.shape[2]) 323 | alpha[: start, prompt_ind, word_inds] = 0 324 | alpha[start: end, prompt_ind, word_inds] = 1 325 | alpha[end:, prompt_ind, word_inds] = 0 326 | return alpha 327 | 328 | 329 | def get_time_words_attention_alpha(prompts, num_steps, 330 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 331 | tokenizer, max_num_words=77): 332 | if type(cross_replace_steps) is not dict: 333 | cross_replace_steps = {"default_": cross_replace_steps} 334 | if "default_" not in cross_replace_steps: 335 | cross_replace_steps["default_"] = (0., 1.) 336 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 337 | for i in range(len(prompts) - 1): 338 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 339 | i) 340 | for key, item in cross_replace_steps.items(): 341 | if key != "default_": 342 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 343 | for i, ind in enumerate(inds): 344 | if len(ind) > 0: 345 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 346 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 347 | return alpha_time_words 348 | -------------------------------------------------------------------------------- /vid2vid_zero/p2p/seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import numpy as np 17 | 18 | 19 | class ScoreParams: 20 | 21 | def __init__(self, gap, match, mismatch): 22 | self.gap = gap 23 | self.match = match 24 | self.mismatch = mismatch 25 | 26 | def mis_match_char(self, x, y): 27 | if x != y: 28 | return self.mismatch 29 | else: 30 | return self.match 31 | 32 | 33 | def get_matrix(size_x, size_y, gap): 34 | matrix = [] 35 | for i in range(len(size_x) + 1): 36 | sub_matrix = [] 37 | for j in range(len(size_y) + 1): 38 | sub_matrix.append(0) 39 | matrix.append(sub_matrix) 40 | for j in range(1, len(size_y) + 1): 41 | matrix[0][j] = j*gap 42 | for i in range(1, len(size_x) + 1): 43 | matrix[i][0] = i*gap 44 | return matrix 45 | 46 | 47 | def get_matrix(size_x, size_y, gap): 48 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 49 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 50 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 51 | return matrix 52 | 53 | 54 | def get_traceback_matrix(size_x, size_y): 55 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 56 | matrix[0, 1:] = 1 57 | matrix[1:, 0] = 2 58 | matrix[0, 0] = 4 59 | return matrix 60 | 61 | 62 | def global_align(x, y, score): 63 | matrix = get_matrix(len(x), len(y), score.gap) 64 | trace_back = get_traceback_matrix(len(x), len(y)) 65 | for i in range(1, len(x) + 1): 66 | for j in range(1, len(y) + 1): 67 | left = matrix[i, j - 1] + score.gap 68 | up = matrix[i - 1, j] + score.gap 69 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 70 | matrix[i, j] = max(left, up, diag) 71 | if matrix[i, j] == left: 72 | trace_back[i, j] = 1 73 | elif matrix[i, j] == up: 74 | trace_back[i, j] = 2 75 | else: 76 | trace_back[i, j] = 3 77 | return matrix, trace_back 78 | 79 | 80 | def get_aligned_sequences(x, y, trace_back): 81 | x_seq = [] 82 | y_seq = [] 83 | i = len(x) 84 | j = len(y) 85 | mapper_y_to_x = [] 86 | while i > 0 or j > 0: 87 | if trace_back[i, j] == 3: 88 | x_seq.append(x[i-1]) 89 | y_seq.append(y[j-1]) 90 | i = i-1 91 | j = j-1 92 | mapper_y_to_x.append((j, i)) 93 | elif trace_back[i][j] == 1: 94 | x_seq.append('-') 95 | y_seq.append(y[j-1]) 96 | j = j-1 97 | mapper_y_to_x.append((j, -1)) 98 | elif trace_back[i][j] == 2: 99 | x_seq.append(x[i-1]) 100 | y_seq.append('-') 101 | i = i-1 102 | elif trace_back[i][j] == 4: 103 | break 104 | mapper_y_to_x.reverse() 105 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 106 | 107 | 108 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 109 | x_seq = tokenizer.encode(x) 110 | y_seq = tokenizer.encode(y) 111 | score = ScoreParams(0, 1, -1) 112 | matrix, trace_back = global_align(x_seq, y_seq, score) 113 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 114 | alphas = torch.ones(max_len) 115 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 116 | mapper = torch.zeros(max_len, dtype=torch.int64) 117 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 118 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 119 | return mapper, alphas 120 | 121 | 122 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 123 | x_seq = prompts[0] 124 | mappers, alphas = [], [] 125 | for i in range(1, len(prompts)): 126 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 127 | mappers.append(mapper) 128 | alphas.append(alpha) 129 | return torch.stack(mappers), torch.stack(alphas) 130 | 131 | 132 | def get_word_inds(text: str, word_place: int, tokenizer): 133 | split_text = text.split(" ") 134 | if type(word_place) is str: 135 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 136 | elif type(word_place) is int: 137 | word_place = [word_place] 138 | out = [] 139 | if len(word_place) > 0: 140 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 141 | cur_len, ptr = 0, 0 142 | 143 | for i in range(len(words_encode)): 144 | cur_len += len(words_encode[i]) 145 | if ptr in word_place: 146 | out.append(i + 1) 147 | if cur_len >= len(split_text[ptr]): 148 | ptr += 1 149 | cur_len = 0 150 | return np.array(out) 151 | 152 | 153 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 154 | words_x = x.split(' ') 155 | words_y = y.split(' ') 156 | if len(words_x) != len(words_y): 157 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 158 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 159 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 160 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 161 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 162 | mapper = np.zeros((max_len, max_len)) 163 | i = j = 0 164 | cur_inds = 0 165 | while i < max_len and j < max_len: 166 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 167 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 168 | if len(inds_source_) == len(inds_target_): 169 | mapper[inds_source_, inds_target_] = 1 170 | else: 171 | ratio = 1 / len(inds_target_) 172 | for i_t in inds_target_: 173 | mapper[inds_source_, i_t] = ratio 174 | cur_inds += 1 175 | i += len(inds_source_) 176 | j += len(inds_target_) 177 | elif cur_inds < len(inds_source): 178 | mapper[i, j] = 1 179 | i += 1 180 | j += 1 181 | else: 182 | mapper[j, j] = 1 183 | i += 1 184 | j += 1 185 | 186 | return torch.from_numpy(mapper).float() 187 | 188 | 189 | 190 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 191 | x_seq = prompts[0] 192 | mappers = [] 193 | for i in range(1, len(prompts)): 194 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 195 | mappers.append(mapper) 196 | return torch.stack(mappers) 197 | 198 | -------------------------------------------------------------------------------- /vid2vid_zero/pipelines/pipeline_vid2vid_zero.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Callable, List, Optional, Union 17 | from dataclasses import dataclass 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.utils import is_accelerate_available 23 | from packaging import version 24 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 25 | 26 | from diffusers.configuration_utils import FrozenDict 27 | from diffusers.models import AutoencoderKL # UNet2DConditionModel 28 | from diffusers.pipeline_utils import DiffusionPipeline 29 | from diffusers.schedulers import ( 30 | DDIMScheduler, 31 | DPMSolverMultistepScheduler, 32 | EulerAncestralDiscreteScheduler, 33 | EulerDiscreteScheduler, 34 | LMSDiscreteScheduler, 35 | PNDMScheduler, 36 | ) 37 | from diffusers.utils import deprecate, logging, BaseOutput 38 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 39 | 40 | from einops import rearrange 41 | 42 | from ..models.unet_2d_condition import UNet2DConditionModel 43 | 44 | 45 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 46 | 47 | 48 | @dataclass 49 | class Vid2VidZeroPipelineOutput(BaseOutput): 50 | images: Union[torch.Tensor, np.ndarray] 51 | 52 | 53 | class Vid2VidZeroPipeline(DiffusionPipeline): 54 | r""" 55 | Pipeline for text-to-image generation using Stable Diffusion. 56 | 57 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 58 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 59 | 60 | Args: 61 | vae ([`AutoencoderKL`]): 62 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 63 | text_encoder ([`CLIPTextModel`]): 64 | Frozen text-encoder. Stable Diffusion uses the text portion of 65 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 66 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 67 | tokenizer (`CLIPTokenizer`): 68 | Tokenizer of class 69 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 70 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 71 | scheduler ([`SchedulerMixin`]): 72 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 73 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 74 | safety_checker ([`StableDiffusionSafetyChecker`]): 75 | Classification module that estimates whether generated images could be considered offensive or harmful. 76 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 77 | feature_extractor ([`CLIPFeatureExtractor`]): 78 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 79 | """ 80 | _optional_components = ["safety_checker", "feature_extractor"] 81 | 82 | def __init__( 83 | self, 84 | vae: AutoencoderKL, 85 | text_encoder: CLIPTextModel, 86 | tokenizer: CLIPTokenizer, 87 | unet: UNet2DConditionModel, 88 | scheduler: Union[ 89 | DDIMScheduler, 90 | PNDMScheduler, 91 | LMSDiscreteScheduler, 92 | EulerDiscreteScheduler, 93 | EulerAncestralDiscreteScheduler, 94 | DPMSolverMultistepScheduler, 95 | ], 96 | safety_checker: StableDiffusionSafetyChecker, 97 | feature_extractor: CLIPFeatureExtractor, 98 | requires_safety_checker: bool = False, 99 | ): 100 | super().__init__() 101 | 102 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 103 | deprecation_message = ( 104 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 105 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 106 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 107 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 108 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 109 | " file" 110 | ) 111 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 112 | new_config = dict(scheduler.config) 113 | new_config["steps_offset"] = 1 114 | scheduler._internal_dict = FrozenDict(new_config) 115 | 116 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 117 | deprecation_message = ( 118 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 119 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 120 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 121 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 122 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 123 | ) 124 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 125 | new_config = dict(scheduler.config) 126 | new_config["clip_sample"] = False 127 | scheduler._internal_dict = FrozenDict(new_config) 128 | 129 | if safety_checker is None and requires_safety_checker: 130 | logger.warning( 131 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 132 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 133 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 134 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 135 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 136 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 137 | ) 138 | 139 | if safety_checker is not None and feature_extractor is None: 140 | raise ValueError( 141 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 142 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 143 | ) 144 | 145 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 146 | version.parse(unet.config._diffusers_version).base_version 147 | ) < version.parse("0.9.0.dev0") 148 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 149 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 150 | deprecation_message = ( 151 | "The configuration file of the unet has set the default `sample_size` to smaller than" 152 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 153 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 154 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 155 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 156 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 157 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 158 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 159 | " the `unet/config.json` file" 160 | ) 161 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 162 | new_config = dict(unet.config) 163 | new_config["sample_size"] = 64 164 | unet._internal_dict = FrozenDict(new_config) 165 | 166 | self.register_modules( 167 | vae=vae, 168 | text_encoder=text_encoder, 169 | tokenizer=tokenizer, 170 | unet=unet, 171 | scheduler=scheduler, 172 | safety_checker=safety_checker, 173 | feature_extractor=feature_extractor, 174 | ) 175 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 176 | self.register_to_config(requires_safety_checker=requires_safety_checker) 177 | 178 | def enable_vae_slicing(self): 179 | r""" 180 | Enable sliced VAE decoding. 181 | 182 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several 183 | steps. This is useful to save some memory and allow larger batch sizes. 184 | """ 185 | self.vae.enable_slicing() 186 | 187 | def disable_vae_slicing(self): 188 | r""" 189 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to 190 | computing decoding in one step. 191 | """ 192 | self.vae.disable_slicing() 193 | 194 | def enable_sequential_cpu_offload(self, gpu_id=0): 195 | r""" 196 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 197 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 198 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 199 | """ 200 | if is_accelerate_available(): 201 | from accelerate import cpu_offload 202 | else: 203 | raise ImportError("Please install accelerate via `pip install accelerate`") 204 | 205 | device = torch.device(f"cuda:{gpu_id}") 206 | 207 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 208 | if cpu_offloaded_model is not None: 209 | cpu_offload(cpu_offloaded_model, device) 210 | 211 | if self.safety_checker is not None: 212 | # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate 213 | # fix by only offloading self.safety_checker for now 214 | cpu_offload(self.safety_checker.vision_model, device) 215 | 216 | @property 217 | def _execution_device(self): 218 | r""" 219 | Returns the device on which the pipeline's models will be executed. After calling 220 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 221 | hooks. 222 | """ 223 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 224 | return self.device 225 | for module in self.unet.modules(): 226 | if ( 227 | hasattr(module, "_hf_hook") 228 | and hasattr(module._hf_hook, "execution_device") 229 | and module._hf_hook.execution_device is not None 230 | ): 231 | return torch.device(module._hf_hook.execution_device) 232 | return self.device 233 | 234 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, uncond_embeddings=None): 235 | r""" 236 | Encodes the prompt into text encoder hidden states. 237 | 238 | Args: 239 | prompt (`str` or `list(int)`): 240 | prompt to be encoded 241 | device: (`torch.device`): 242 | torch device 243 | num_images_per_prompt (`int`): 244 | number of images that should be generated per prompt 245 | do_classifier_free_guidance (`bool`): 246 | whether to use classifier free guidance or not 247 | negative_prompt (`str` or `List[str]`): 248 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 249 | if `guidance_scale` is less than `1`). 250 | """ 251 | batch_size = len(prompt) if isinstance(prompt, list) else 1 252 | 253 | text_inputs = self.tokenizer( 254 | prompt, 255 | padding="max_length", 256 | max_length=self.tokenizer.model_max_length, 257 | truncation=True, 258 | return_tensors="pt", 259 | ) 260 | text_input_ids = text_inputs.input_ids 261 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 262 | 263 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 264 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 265 | logger.warning( 266 | "The following part of your input was truncated because CLIP can only handle sequences up to" 267 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 268 | ) 269 | 270 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 271 | attention_mask = text_inputs.attention_mask.to(device) 272 | else: 273 | attention_mask = None 274 | 275 | text_embeddings = self.text_encoder( 276 | text_input_ids.to(device), 277 | attention_mask=attention_mask, 278 | ) 279 | text_embeddings = text_embeddings[0] 280 | 281 | # duplicate text embeddings for each generation per prompt, using mps friendly method 282 | # num_videos_per_prompt = 1, thus nothing happens here 283 | bs_embed, seq_len, _ = text_embeddings.shape 284 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 285 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 286 | 287 | # get unconditional embeddings for classifier free guidance 288 | if do_classifier_free_guidance: 289 | uncond_tokens: List[str] 290 | if negative_prompt is None: 291 | uncond_tokens = [""] * batch_size 292 | elif type(prompt) is not type(negative_prompt): 293 | raise TypeError( 294 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 295 | f" {type(prompt)}." 296 | ) 297 | elif isinstance(negative_prompt, str): 298 | uncond_tokens = [negative_prompt] 299 | elif batch_size != len(negative_prompt): 300 | raise ValueError( 301 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 302 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 303 | " the batch size of `prompt`." 304 | ) 305 | else: 306 | uncond_tokens = negative_prompt 307 | 308 | max_length = text_input_ids.shape[-1] 309 | uncond_input = self.tokenizer( 310 | uncond_tokens, 311 | padding="max_length", 312 | max_length=max_length, 313 | truncation=True, 314 | return_tensors="pt", 315 | ) 316 | 317 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 318 | attention_mask = uncond_input.attention_mask.to(device) 319 | else: 320 | attention_mask = None 321 | 322 | uncond_embeddings = self.text_encoder( 323 | uncond_input.input_ids.to(device), 324 | attention_mask=attention_mask, 325 | ) 326 | uncond_embeddings = uncond_embeddings[0] 327 | 328 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 329 | seq_len = uncond_embeddings.shape[1] 330 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 331 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 332 | 333 | # For classifier free guidance, we need to do two forward passes. 334 | # Here we concatenate the unconditional and text embeddings into a single batch 335 | # to avoid doing two forward passes 336 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 337 | 338 | return text_embeddings 339 | 340 | def run_safety_checker(self, image, device, dtype): 341 | if self.safety_checker is not None: 342 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 343 | image, has_nsfw_concept = self.safety_checker( 344 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 345 | ) 346 | else: 347 | has_nsfw_concept = None 348 | return image, has_nsfw_concept 349 | 350 | def decode_latents(self, latents): 351 | video_length = latents.shape[2] 352 | latents = 1 / 0.18215 * latents 353 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 354 | video = self.vae.decode(latents).sample 355 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 356 | video = (video / 2 + 0.5).clamp(0, 1) 357 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 358 | video = video.cpu().float().numpy() 359 | return video 360 | 361 | def prepare_extra_step_kwargs(self, generator, eta): 362 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 363 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 364 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 365 | # and should be between [0, 1] 366 | 367 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 368 | extra_step_kwargs = {} 369 | if accepts_eta: 370 | extra_step_kwargs["eta"] = eta 371 | 372 | # check if the scheduler accepts generator 373 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 374 | if accepts_generator: 375 | extra_step_kwargs["generator"] = generator 376 | return extra_step_kwargs 377 | 378 | def check_inputs(self, prompt, height, width, callback_steps): 379 | if not isinstance(prompt, str) and not isinstance(prompt, list): 380 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 381 | 382 | if height % 8 != 0 or width % 8 != 0: 383 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 384 | 385 | if (callback_steps is None) or ( 386 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 387 | ): 388 | raise ValueError( 389 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 390 | f" {type(callback_steps)}." 391 | ) 392 | 393 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): 394 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 395 | if isinstance(generator, list) and len(generator) != batch_size: 396 | raise ValueError( 397 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 398 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 399 | ) 400 | 401 | if latents is None: 402 | rand_device = "cpu" if device.type == "mps" else device 403 | 404 | if isinstance(generator, list): 405 | shape = (1,) + shape[1:] 406 | latents = [ 407 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 408 | for i in range(batch_size) 409 | ] 410 | latents = torch.cat(latents, dim=0).to(device) 411 | else: 412 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 413 | else: 414 | if latents.shape != shape: 415 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 416 | latents = latents.to(device) 417 | 418 | # scale the initial noise by the standard deviation required by the scheduler 419 | latents = latents * self.scheduler.init_noise_sigma 420 | return latents 421 | 422 | @torch.no_grad() 423 | def __call__( 424 | self, 425 | prompt: Union[str, List[str]], 426 | video_length: Optional[int], 427 | height: Optional[int] = None, 428 | width: Optional[int] = None, 429 | num_inference_steps: int = 50, 430 | guidance_scale: float = 7.5, 431 | negative_prompt: Optional[Union[str, List[str]]] = None, 432 | num_videos_per_prompt: Optional[int] = 1, 433 | eta: float = 0.0, 434 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 435 | latents: Optional[torch.FloatTensor] = None, 436 | output_type: Optional[str] = "tensor", 437 | return_dict: bool = True, 438 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 439 | callback_steps: Optional[int] = 1, 440 | uncond_embeddings: torch.Tensor = None, 441 | null_uncond_ratio: float = 1.0, 442 | **kwargs, 443 | ): 444 | # Default height and width to unet 445 | height = height or self.unet.config.sample_size * self.vae_scale_factor 446 | width = width or self.unet.config.sample_size * self.vae_scale_factor 447 | 448 | # Check inputs. Raise error if not correct 449 | self.check_inputs(prompt, height, width, callback_steps) 450 | 451 | # Define call parameters 452 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 453 | device = self._execution_device 454 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 455 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 456 | # corresponds to doing no classifier free guidance. 457 | do_classifier_free_guidance = guidance_scale > 1.0 458 | 459 | # Encode input prompt 460 | with_uncond_embedding = do_classifier_free_guidance if uncond_embeddings is None else False 461 | text_embeddings = self._encode_prompt( 462 | prompt, device, num_videos_per_prompt, with_uncond_embedding, negative_prompt, 463 | ) 464 | 465 | # Prepare timesteps 466 | self.scheduler.set_timesteps(num_inference_steps, device=device) 467 | timesteps = self.scheduler.timesteps 468 | 469 | # Prepare latent variables 470 | num_channels_latents = self.unet.in_channels 471 | latents = self.prepare_latents( 472 | batch_size * num_videos_per_prompt, 473 | num_channels_latents, 474 | video_length, 475 | height, 476 | width, 477 | text_embeddings.dtype, 478 | device, 479 | generator, 480 | latents, 481 | ) 482 | latents_dtype = latents.dtype 483 | 484 | # Prepare extra step kwargs. 485 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 486 | 487 | # Denoising loop 488 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 489 | with self.progress_bar(total=num_inference_steps) as progress_bar: 490 | if uncond_embeddings is not None: 491 | start_time = 50 492 | assert (timesteps[-start_time:] == timesteps).all() 493 | for i, t in enumerate(timesteps): 494 | # expand the latents if we are doing classifier free guidance 495 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 496 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 497 | 498 | if uncond_embeddings is not None: 499 | use_uncond_this_step = True 500 | if null_uncond_ratio > 0: 501 | if i > len(timesteps) * null_uncond_ratio: 502 | use_uncond_this_step = False 503 | else: 504 | if i < len(timesteps) * (1 + null_uncond_ratio): 505 | use_uncond_this_step = False 506 | if use_uncond_this_step: 507 | text_embeddings_input = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]) 508 | else: 509 | uncond_embeddings_ = self._encode_prompt('', device, num_videos_per_prompt, False, negative_prompt) 510 | text_embeddings_input = torch.cat([uncond_embeddings_.expand(*text_embeddings.shape), text_embeddings]) 511 | else: 512 | text_embeddings_input = text_embeddings 513 | 514 | # predict the noise residual 515 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings_input).sample.to(dtype=latents_dtype) 516 | 517 | # perform guidance 518 | if do_classifier_free_guidance: 519 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 520 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 521 | 522 | # compute the previous noisy sample x_t -> x_t-1 523 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 524 | 525 | # call the callback, if provided 526 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 527 | progress_bar.update() 528 | if callback is not None and i % callback_steps == 0: 529 | callback(i, t, latents) 530 | 531 | # Post-processing 532 | images = self.decode_latents(latents) 533 | 534 | # Convert to tensor 535 | if output_type == "tensor": 536 | images = torch.from_numpy(images) 537 | 538 | if not return_dict: 539 | return images 540 | 541 | return Vid2VidZeroPipelineOutput(images=images) 542 | -------------------------------------------------------------------------------- /vid2vid_zero/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import tempfile 4 | import numpy as np 5 | from PIL import Image 6 | from typing import Union 7 | 8 | import torch 9 | import torchvision 10 | 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | 15 | def save_videos_as_images(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=1): 16 | dir_name = os.path.dirname(path) 17 | videos = rearrange(videos, "b c t h w -> t b h w c") 18 | 19 | os.makedirs(os.path.join(dir_name, "vis_images"), exist_ok=True) 20 | for frame_idx, x in enumerate(videos): 21 | if rescale: 22 | x = (x + 1.0) / 2.0 23 | x = (x * 255).numpy().astype(np.uint8) 24 | 25 | for batch_idx, image in enumerate(x): 26 | save_dir = os.path.join(dir_name, "vis_images", f"batch_{batch_idx}") 27 | os.makedirs(save_dir, exist_ok=True) 28 | save_path = os.path.join(save_dir, f"frame_{frame_idx}.png") 29 | image = Image.fromarray(image) 30 | image.save(save_path) 31 | 32 | 33 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=1): 34 | videos = rearrange(videos, "b c t h w -> t b c h w") 35 | outputs = [] 36 | for x in videos: 37 | x = torchvision.utils.make_grid(x, nrow=n_rows) 38 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 39 | if rescale: 40 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 41 | x = (x * 255).numpy().astype(np.uint8) 42 | outputs.append(x) 43 | 44 | os.makedirs(os.path.dirname(path), exist_ok=True) 45 | imageio.mimsave(path, outputs, fps=8) 46 | 47 | # save for gradio demo 48 | out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) 49 | out_file.name = path.replace('.gif', '.mp4') 50 | writer = imageio.get_writer(out_file.name, fps=fps) 51 | for frame in outputs: 52 | writer.append_data(frame) 53 | writer.close() 54 | 55 | 56 | @torch.no_grad() 57 | def init_prompt(prompt, pipeline): 58 | uncond_input = pipeline.tokenizer( 59 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 60 | return_tensors="pt" 61 | ) 62 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 63 | text_input = pipeline.tokenizer( 64 | [prompt], 65 | padding="max_length", 66 | max_length=pipeline.tokenizer.model_max_length, 67 | truncation=True, 68 | return_tensors="pt", 69 | ) 70 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 71 | context = torch.cat([uncond_embeddings, text_embeddings]) 72 | 73 | return context 74 | 75 | 76 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 77 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 78 | timestep, next_timestep = min( 79 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 80 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 81 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 82 | beta_prod_t = 1 - alpha_prod_t 83 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 84 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 85 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 86 | return next_sample 87 | 88 | 89 | def get_noise_pred_single(latents, t, context, unet, normal_infer=False): 90 | bs = latents.shape[0] # (b*f, c, h, w) or (b, c, f, h, w) 91 | if bs != context.shape[0]: 92 | context = context.repeat(bs, 1, 1) # (b*f, len, dim) 93 | noise_pred = unet(latents, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"] 94 | return noise_pred 95 | 96 | 97 | @torch.no_grad() 98 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, normal_infer=False): 99 | context = init_prompt(prompt, pipeline) 100 | uncond_embeddings, cond_embeddings = context.chunk(2) 101 | all_latent = [latent] 102 | latent = latent.clone().detach() 103 | for i in tqdm(range(num_inv_steps)): 104 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 105 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet, normal_infer=normal_infer) 106 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 107 | all_latent.append(latent) 108 | return all_latent 109 | 110 | 111 | @torch.no_grad() 112 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", normal_infer=False): 113 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, normal_infer=normal_infer) 114 | return ddim_latents 115 | --------------------------------------------------------------------------------