├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── app_svd.py ├── colab.ipynb ├── compress_video.py ├── docs ├── 4_sr.mp4 ├── barbie2.mp4 ├── fish.gif ├── fish.jpg ├── fish_mask.png ├── framework.png ├── girl5.mp4 ├── labelme.png ├── pig0.mp4 ├── qingming2.gif ├── qingming2_label.jpg ├── sample_1.gif ├── sample_1.png ├── sample_2.gif ├── sample_2.png ├── sample_3.gif └── sample_3.png ├── example ├── barbie.jpg ├── barbie2.jpg ├── deepspeed.yaml ├── example_padded_rgba_pngs │ ├── apple.png │ ├── put rgba images here for train_transparent_i2v_stage2.py.txt │ └── ziyan0.png ├── example_rgba_video_results │ ├── animated rgba results for our transparent unet.txt │ ├── apple │ │ ├── decoded_alpha.webp │ │ └── decoded_rgba.webp │ └── ziyan0 │ │ ├── decoded_alpha.webp │ │ └── decoded_rgba.webp ├── fish1.jpg ├── fish1_label.jpg ├── girl5.jpg ├── hulu2.jpg ├── hulu3.jpg ├── layerdiffuse_stage2_384.yaml ├── pig0.jpg ├── pig0_label.jpg ├── qingming2.jpg ├── qingming2_label.jpg ├── train_mask_motion.yaml ├── train_mask_motion_lora.yaml ├── train_svd.yaml ├── train_svd_mask.yaml ├── train_svd_v2v.yaml └── validation_file.json ├── models ├── layerdiffuse_VAE.py ├── pipeline.py ├── pipeline_stage2.py ├── unet_3d_blocks.py └── unet_3d_condition_mask.py ├── requirements.txt ├── run.sh ├── stable_lora └── lora.py ├── svd_video2video_examples ├── barbie_input.mp4 ├── barbie_mask.png ├── barbie_output.mp4 ├── car_input.mp4 ├── car_mask_1.png ├── car_mask_2.png ├── car_output_1.mp4 ├── car_output_2.mp4 ├── windmill_input.mp4 ├── windmill_mask.png └── windmill_output.mp4 ├── train.py ├── train_lora.py ├── train_svd.py ├── train_transparent_i2v_stage2.py └── utils ├── __init__.py ├── bucketing.py ├── common.py ├── convert_diffusers_to_original_ms_text_to_video.py ├── dataset.py ├── lama.py ├── lora.py ├── lora_handler.py ├── ptp_utils.py └── seq_aligner.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | models/lama.ckpt 3 | .vscode/ 4 | models/model_scope_diffusers/ 5 | text-to-video-ms-1.7b/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | configs 137 | output 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

👉 AnimateAnything: Fine-Grained Open Domain Image Animation with Motion Guidance

3 | 4 | [Zuozhuo Dai](), [Zhenghao Zhang](), [Menghao Li](), [Junchao Liao](), [Siyu Zhu](), [Long Qin](), [Weizhi Wang]() 5 | 6 | 7 | ![views](https://visitor-badge.laobi.icu/badge?page_id=alibaba.animate-anything&left_color=gray&right_color=red) 8 | 9 |
10 | 11 | ## Friendship Link 🔥 12 | - We are excited to announce the open-source release of our latest work: [Tora: Trajectory-oriented Diffusion Transformer for Video Generation](https://github.com/alibaba/Tora). It is the first trajectory-oriented DiT framework that concurrently integrates textual, visual, and trajectory conditions for video generation. 13 | 14 | ## Showcases 15 | 16 | https://github.com/alibaba/animate-anything/assets/1107525/e2659674-c813-402a-8a85-e620f0a6a454 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 28 | 29 | 30 | 31 | 32 | 34 | 35 | 36 | 37 | 38 | 40 | 41 | 42 |
Input Image with MaskPromptResult
Input image 27 | Barbie watching the camera with a smiling face.Result
Input image 33 | The cloak swaying in the wind.Result
Input image 39 | A red fish is swimming.Result
43 | 44 | 47 | 48 | ## Framework 49 | ![framework](docs/framework.png) 50 | 51 | ## News 🔥 52 | **2024.2.5**: Support multi-GPU training with Accelerate and DeepSpeed. With DeepSpeed zero_stage 2 and offload_optimizer_device cpu, you can now fully finetune animate-anything on 4x16G V100 GPUs and SVD on 4x24G A10 GPUs. 53 | 54 | **2023.12.27**: Support finetuning based on the SVD (Stable Video Diffusion) model. Update the SVD-based model animate_anything_svd_v1.0. 55 | 56 | **2023.12.18**: Update the model to animate_anything_512_v1.02. 57 | 58 | ## Features Planned 59 | - 💥 Transparent video generation (take an RGBA image as input and output animated RGBA videos). 60 | - ✅ Reproduce the transparent VAE encoder and decoder following [LayerDiffuse](https://github.com/layerdiffusion/sd-forge-layerdiffuse). 61 | - ✅ Finetune the 3D U-Net to support basic RGBA-image-to-RGBA-video generation. 62 | - 💥 Enhanced prompt-following: generating long, detailed captions using LLaVA. 63 | - 💥 Replace the U-Net with a Diffusion Transformer (DiT) as the base model. 64 | - 💥 Variable resolutions and aspect ratios. 65 | - 💥 Support Hugging Face Demo / Google Colab. 66 | - ✅ Support the SVD video2video Google Colab demo; see colab.ipynb. 67 | - ✅ Support LoRA finetuning. 68 | - etc. 69 | 70 | ## Getting Started 71 | This repository is based on [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning.git). 72 | 73 | ### Create Conda Environment (Optional) 74 | It is recommended to install Anaconda. 75 | 76 | **Windows Installation:** https://docs.anaconda.com/anaconda/install/windows/ 77 | 78 | **Linux Installation:** https://docs.anaconda.com/anaconda/install/linux/ 79 | 80 | ```bash 81 | conda create -n animation python=3.10 82 | conda activate animation 83 | ``` 84 | 85 | ### Python Requirements 86 | ```bash 87 | pip install -r requirements.txt 88 | ``` 89 | 90 | ## Running inference 91 | Please download the [pretrained model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_512_v1.02.tar) to output/latent, then run the following command, replacing {download_model} with the name of the downloaded model: 92 | ```bash 93 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/barbie2.jpg validation_data.prompt='A cartoon girl is talking.' 94 | ``` 95 | 96 | To control the motion area, we can use labelme to generate a binary mask. First, use labelme to draw a polygon over the region of the reference image that should move. 97 | 98 | ![](docs/labelme.png) 99 | 100 | Then run the following command to convert the labelme JSON file into a mask: 101 | 102 | ```bash 103 | labelme_json_to_dataset qingming2.json 104 | ``` 105 | ![](docs/qingming2_label.jpg) 106 | 107 | Then run the following command for inference: 108 | ```bash 109 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg 110 | ``` 111 | ![](docs/qingming2.gif) 112 | 113 | 114 | Users can adjust the motion strength with the mask motion model: 115 | ```bash 116 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='People are walking on the street.' validation_data.mask=example/qingming2_label.jpg validation_data.strength=5 118 | ```
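The mask itself does not have to come from labelme: a binary image whose white region marks the area you want to animate should also work as validation_data.mask. A minimal Pillow sketch, assuming you already know the polygon coordinates of the motion region (the coordinates and output path below are placeholders):

```python
from PIL import Image, ImageDraw

# Match the resolution of the reference image.
ref = Image.open("example/qingming2.jpg")

# Start from an all-black mask (no motion anywhere).
mask = Image.new("L", ref.size, 0)

# Fill the region that should be animated with white (255).
# Placeholder polygon; replace with your own coordinates.
ImageDraw.Draw(mask).polygon(
    [(100, 220), (460, 220), (460, 400), (100, 400)], fill=255
)

mask.save("example/my_mask.png")  # pass this path as validation_data.mask
```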
119 | ## Video super resolution 120 | The model outputs low-resolution videos; you can use a video super-resolution model to produce high-resolution videos. For example, we can use [Real-CUGAN](https://github.com/bilibili/ailab/tree/main/Real-CUGAN) for cartoon-style video super resolution: 121 | 122 | ```bash 123 | git clone https://github.com/bilibili/ailab.git 124 | cd ailab/Real-CUGAN 125 | python inference_video.py 126 | ``` 127 | 128 | ## Training 129 | 130 | ### Using Captions 131 | 132 | You can use caption files when training on videos. Simply place the videos into a folder and create a JSON file with captions like this: 133 | 134 | ``` 135 | [ 136 | {"caption": "Cute monster character flat design animation video", "video": "000001_000050/1066697179.mp4"}, 137 | {"caption": "Landscape of the cherry blossom", "video": "000001_000050/1066688836.mp4"} 138 | ] 139 | 140 | ``` 141 | Then in your config, make sure to set dataset_types to video_json and set the video_dir and video_json paths like this: 142 | ``` 143 | - dataset_types: 144 | - video_json 145 | train_data: 146 | video_dir: '/webvid/webvid/data/videos' 147 | video_json: '/webvid/webvid/data/40K.json' 148 | ```
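If your captions already live in per-clip .txt files stored next to the videos, a small script can assemble this JSON for you. A minimal sketch, assuming one .txt caption file per .mp4 clip with the same basename (the paths below are placeholders; point them at your own data):

```python
import json
from pathlib import Path

video_dir = Path("/webvid/webvid/data/videos")  # same folder as train_data.video_dir

entries = []
for video in sorted(video_dir.rglob("*.mp4")):
    caption_file = video.with_suffix(".txt")
    if not caption_file.exists():
        continue  # skip clips without a caption
    entries.append({
        "caption": caption_file.read_text(encoding="utf-8").strip(),
        # "video" is stored relative to video_dir, matching the example above.
        "video": str(video.relative_to(video_dir)),
    })

with open("/webvid/webvid/data/40K.json", "w") as f:
    json.dump(entries, f)
```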
149 | ### Process Automatically 150 | 151 | You can automatically caption the videos using the [Video-BLIP2-Preprocessor Script](https://github.com/ExponentialML/Video-BLIP2-Preprocessor) and set the dataset_types and json_path like this: 152 | ``` 153 | - dataset_types: 154 | - video_blip 155 | train_data: 156 | json_path: 'blip_generated.json' 157 | ``` 158 | 159 | ### Configuration 160 | 161 | The configuration uses a YAML config borrowed from the [Tune-A-Video](https://github.com/showlab/Tune-A-Video) repository. 162 | 163 | All configuration details are placed in `example/train_mask_motion.yaml`. Each parameter has a definition for what it does. 164 | 165 | 166 | ### Finetuning animate-anything 167 | You can finetune animate-anything with text, motion mask, and motion strength guidance on your own dataset. The following config requires around 30G of GPU RAM. You can reduce train_batch_size, train_data.width, train_data.height, and n_sample_frames in the config to reduce GPU RAM usage: 168 | ``` 169 | python train.py --config example/train_mask_motion.yaml pretrained_model_path= 170 | ``` 171 | 172 | We also support LoRA finetuning: 173 | ``` 174 | python train_lora.py --config example/train_mask_motion_lora.yaml pretrained_model_path= 175 | ``` 176 | 177 | ### Finetune Stable Video Diffusion 178 | The Stable Video Diffusion (SVD) img2vid model can generate high-resolution videos. However, it does not support text or motion mask control. You can finetune SVD with motion mask guidance using the following command and the [pretrained SVD model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_svd_v1.0.tar). This config requires around 80G of GPU RAM. 179 | ``` 180 | python train_svd.py --config example/train_svd_mask.yaml pretrained_model_path= 181 | ``` 182 | 183 | If you only want to finetune SVD on your own dataset without motion mask control, please use the following config: 184 | ``` 185 | python train_svd.py --config example/train_svd.yaml pretrained_model_path= 186 | ``` 187 | 188 | ### Multiple GPUs training 189 | I strongly recommend multi-GPU training with Accelerate, which greatly reduces the VRAM requirement. Please first configure Accelerate with DeepSpeed; an example config is located in example/deepspeed.yaml. 190 | 191 | Then replace the 'python train_xx.py ...' commands above with 'accelerate launch train_xx.py ...', for example: 192 | ``` 193 | accelerate launch --config_file example/deepspeed.yaml train_svd.py --config example/train_svd_mask.yaml pretrained_model_path= 194 | ``` 195 | 196 | ### SVD video2video 197 | We have released the finetuned vid2vid SVD model; you can try it via the Gradio UI. 198 | 199 | Please download the [vid2vid_SVD model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_svd_v1.01.tar), extract it to output/svd/{download_model}, and then run the command: 200 | ``` 201 | python app_svd.py --config example/train_svd_v2v.yaml pretrained_model_path=output/svd/{download_model} 202 | ``` 203 | 204 | We provide several examples in the svd_video2video_examples directory. 205 | 206 | ## Bibtex 207 | Please cite this paper if you find the code useful for your research: 208 | ``` 209 | @misc{dai2023animateanything, 210 | title={AnimateAnything: Fine-Grained Open Domain Image Animation with Motion Guidance}, 211 | author={Zuozhuo Dai and Zhenghao Zhang and Yao Yao and Bingxue Qiu and Siyu Zhu and Long Qin and Weizhi Wang}, 212 | year={2023}, 213 | eprint={2311.12886}, 214 | archivePrefix={arXiv}, 215 | primaryClass={cs.CV} 216 | } 217 | ``` 218 | ## Shoutouts 219 | 220 | - [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning.git) 221 | - [Showlab](https://github.com/showlab/Tune-A-Video) and [bryandlee](https://github.com/bryandlee/Tune-A-Video) for their Tune-A-Video contribution that made this much easier. 222 | - [lucidrains](https://github.com/lucidrains) for their implementations around video diffusion. 223 | - [cloneofsimo](https://github.com/cloneofsimo) for their diffusers implementation of LoRA. 224 | - [kabachuha](https://github.com/kabachuha) for their conversion scripts, training ideas, and webui works. 225 | - [JCBrouwer](https://github.com/JCBrouwer) for inference implementations. 226 | - [sergiobr](https://github.com/sergiobr) for helpful ideas and bug fixes. 227 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import random 5 | from argparse import ArgumentParser 6 | from datetime import datetime 7 | import math 8 | 9 | import gradio as gr 10 | import numpy as np 11 | import torch 12 | from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler 13 | from diffusers.image_processor import VaeImageProcessor 14 | from omegaconf import OmegaConf 15 | from PIL import Image 16 | import torchvision.transforms as T 17 | from einops import rearrange, repeat 18 | import imageio 19 | 20 | from models.pipeline import LatentToVideoPipeline 21 | from utils.common import tensor_to_vae_latent, DDPM_forward 22 | 23 | css = """ 24 | .toolbutton { 25 | margin-buttom: 0em 0em 0em 0em; 26 | max-width: 2.5em; 27 | min-width: 2.5em !important; 28 | height: 2.5em; 29 | } 30 | """ 31 | 32 | 33 | class AnimateController: 34 | def __init__(self, pretrained_model_path: str, validation_data, 35 | output_dir, motion_mask = False, motion_strength = False): 36 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 37 | # as these models are only used for inference, keeping weights in full precision is not required. 
38 | device=torch.device("cuda") 39 | self.validation_data = validation_data 40 | self.output_dir = output_dir 41 | self.pipeline = LatentToVideoPipeline.from_pretrained(pretrained_model_path, 42 | torch_dtype=torch.float16, variant="fp16").to(device) 43 | self.sample_idx = 0 44 | 45 | def animate( 46 | self, 47 | init_img, 48 | motion_scale, 49 | prompt_textbox, 50 | negative_prompt_textbox, 51 | sample_step_slider, 52 | cfg_scale_slider, 53 | seed_textbox, 54 | style, 55 | progress=gr.Progress(), 56 | ): 57 | 58 | if seed_textbox != "-1" and seed_textbox != "": 59 | torch.manual_seed(int(seed_textbox)) 60 | else: 61 | torch.seed() 62 | seed = torch.initial_seed() 63 | 64 | vae = self.pipeline.vae 65 | diffusion_scheduler = self.pipeline.scheduler 66 | validation_data = self.validation_data 67 | vae_processor = VaeImageProcessor() 68 | 69 | device = vae.device 70 | dtype = vae.dtype 71 | 72 | pimg = Image.fromarray(init_img["background"]).convert('RGB') 73 | width, height = pimg.size 74 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width)) 75 | block_size=8 76 | height = round(height/scale/block_size)*block_size 77 | width = round(width/scale/block_size)*block_size 78 | input_image = vae_processor.preprocess(pimg, height, width) 79 | input_image = input_image.unsqueeze(0).to(dtype).to(device) 80 | input_image_latents = tensor_to_vae_latent(input_image, vae) 81 | np_mask = init_img["layers"][0][:,:,3] 82 | np_mask[np_mask!=0] = 255 83 | if np_mask.sum() == 0: 84 | np_mask[:] = 255 85 | save_sample_path = os.path.join( 86 | self.output_dir, f"{self.sample_idx}.mp4") 87 | out_mask_path = os.path.splitext(save_sample_path)[0] + "_mask.jpg" 88 | Image.fromarray(np_mask).save(out_mask_path) 89 | 90 | b, c, _, h, w = input_image_latents.shape 91 | initial_latents, timesteps = DDPM_forward(input_image_latents, 92 | sample_step_slider, validation_data.num_frames, diffusion_scheduler) 93 | mask = T.ToTensor()(np_mask).to(dtype).to(device) 94 | b, c, f, h, w = initial_latents.shape 95 | mask = T.Resize([h, w], antialias=False)(mask) 96 | mask = rearrange(mask, 'b h w -> b 1 1 h w') 97 | motion_strength = motion_scale * mask.mean().item() 98 | print(f"outfile {save_sample_path}, prompt {prompt_textbox}, motion_strength {motion_strength}") 99 | with torch.no_grad(): 100 | video_frames, video_latents = self.pipeline( 101 | prompt=prompt_textbox, 102 | latents=initial_latents, 103 | width=width, 104 | height=height, 105 | num_frames=validation_data.num_frames, 106 | num_inference_steps=sample_step_slider, 107 | guidance_scale=cfg_scale_slider, 108 | condition_latent=input_image_latents, 109 | mask=mask, 110 | motion=[motion_strength], 111 | return_dict=False, 112 | timesteps=timesteps, 113 | ) 114 | 115 | imageio.mimwrite(save_sample_path, video_frames, fps=8) 116 | self.sample_idx += 1 117 | return save_sample_path 118 | 119 | 120 | def ui(controller): 121 | with gr.Blocks(css=css) as demo: 122 | 123 | gr.HTML( 124 | "
Animate Anything
" 125 | ) 126 | with gr.Row(): 127 | gr.Markdown( 128 | "
Project Page  " # noqa 129 | "Paper  " 130 | "Code  " # noqa 131 | "

Instructions: 1. Upload an image. 2. Draw a mask on the image using the draw button. 3. Write a prompt. 4. Click the Generate button. If there is no response, please click again.

" 132 | ) 133 | 134 | with gr.Row(equal_height=False): 135 | with gr.Column(): 136 | with gr.Row(): 137 | init_img = gr.ImageMask(label='Input Image', brush=gr.Brush(default_size=100)) 138 | style_dropdown = gr.Dropdown(label='Style', choices=['384', '512']) 139 | with gr.Row(): 140 | prompt_textbox = gr.Textbox(label="Prompt", value='moving', lines=1) 141 | 142 | motion_scale_silder = gr.Slider( 143 | label='Motion Strength (Larger value means larger motion but less identity consistency)', 144 | value=5, step=1, minimum=1, maximum=20) 145 | 146 | with gr.Accordion('Advance Options', open=False): 147 | negative_prompt_textbox = gr.Textbox( 148 | value="", label="Negative prompt", lines=2) 149 | 150 | sample_step_slider = gr.Slider( 151 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1) 152 | 153 | cfg_scale_slider = gr.Slider( 154 | label="CFG Scale", value=9, minimum=0, maximum=20) 155 | 156 | with gr.Row(): 157 | seed_textbox = gr.Textbox(label="Seed", value=-1) 158 | seed_button = gr.Button( 159 | value="\U0001F3B2", elem_classes="toolbutton") 160 | seed_button.click( 161 | fn=lambda x: random.randint(1, 1e8), 162 | outputs=[seed_textbox], 163 | queue=False 164 | ) 165 | 166 | generate_button = gr.Button( 167 | value="Generate", variant='primary') 168 | 169 | result_video = gr.Video( 170 | label="Generated Animation", interactive=False) 171 | 172 | generate_button.click( 173 | fn=controller.animate, 174 | inputs=[ 175 | init_img, 176 | motion_scale_silder, 177 | prompt_textbox, 178 | negative_prompt_textbox, 179 | sample_step_slider, 180 | cfg_scale_slider, 181 | seed_textbox, 182 | style_dropdown, 183 | ], 184 | outputs=[result_video] 185 | ) 186 | 187 | def create_example(input_list): 188 | return gr.Examples( 189 | examples=input_list, 190 | inputs=[ 191 | init_img, 192 | result_video, 193 | prompt_textbox, 194 | style_dropdown, 195 | motion_scale_silder, 196 | ], 197 | ) 198 | 199 | gr.Markdown( 200 | '### Merry Christmas!' 
201 | ) 202 | create_example( 203 | [ 204 | [ 'example/pig0.jpg', 'docs/pig0.mp4', 'pigs are talking', '512', 3], 205 | [ 'example/barbie2.jpg', 'docs/barbie2.mp4', 'a girl is talking', '512', 4], 206 | ], 207 | 208 | ) 209 | 210 | return demo 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = ArgumentParser() 215 | parser.add_argument('--config', type=str, default='example/config/base.yaml') 216 | parser.add_argument('--server-name', type=str, default='0.0.0.0') 217 | parser.add_argument('--port', type=int, default=7860) 218 | parser.add_argument('--share', action='store_true', default=False) 219 | parser.add_argument('--local-debug', action='store_true') 220 | parser.add_argument('--save-path', default='samples') 221 | 222 | args, unknownargs = parser.parse_known_args() 223 | LOCAL_DEBUG = args.local_debug 224 | args_dict = OmegaConf.load(args.config) 225 | cli_conf = OmegaConf.from_cli() 226 | args_dict = OmegaConf.merge(args_dict, cli_conf) 227 | controller = AnimateController(args_dict.pretrained_model_path, args_dict.validation_data, 228 | args_dict.output_dir, args_dict.motion_mask, args_dict.motion_strength) 229 | demo = ui(controller) 230 | demo.queue(max_size=10) 231 | demo.launch(server_name=args.server_name, 232 | server_port=args.port, max_threads=40, 233 | allowed_paths=['example/barbie2.jpg'], 234 | share=args.share) 235 | -------------------------------------------------------------------------------- /app_svd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from argparse import ArgumentParser 4 | import math 5 | 6 | import gradio as gr 7 | import torch 8 | from diffusers.image_processor import VaeImageProcessor 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | import torchvision.transforms as T 12 | import imageio 13 | 14 | from diffusers import StableVideoDiffusionPipeline 15 | from models.pipeline import TextStableVideoDiffusionPipeline 16 | from einops import rearrange, repeat 17 | from utils.common import read_video 18 | 19 | css = """ 20 | .toolbutton { 21 | margin-buttom: 0em 0em 0em 0em; 22 | max-width: 2.5em; 23 | min-width: 2.5em !important; 24 | height: 2.5em; 25 | } 26 | """ 27 | 28 | 29 | class AnimateController: 30 | def __init__(self, pretrained_model_path: str, validation_data, 31 | output_dir, motion_mask = False, motion_strength = False): 32 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 33 | # as these models are only used for inference, keeping weights in full precision is not required. 
34 | device=torch.device("cuda") 35 | self.validation_data = validation_data 36 | self.output_dir = output_dir 37 | self.pipeline = StableVideoDiffusionPipeline.from_pretrained(pretrained_model_path, torch_dtype=torch.float16, variant="fp16").to(device) 38 | #self.pipeline = StableVideoDiffusionPipeline.from_pretrained(pretrained_model_path).to(device) 39 | self.sample_idx = 0 40 | 41 | def animate( 42 | self, 43 | init_img, 44 | input_video, 45 | sample_step_slider, 46 | seed_textbox, 47 | fps_textbox, 48 | num_frames_textbox, 49 | motion_bucket_id_slider, 50 | progress=gr.Progress(), 51 | ): 52 | 53 | if seed_textbox != "-1" and seed_textbox != "": 54 | torch.manual_seed(int(seed_textbox)) 55 | else: 56 | torch.seed() 57 | seed = torch.initial_seed() 58 | 59 | with torch.no_grad(): 60 | vae = self.pipeline.vae 61 | validation_data = self.validation_data 62 | validation_data.fps = int(fps_textbox) 63 | validation_data.num_frames = int(num_frames_textbox) 64 | validation_data.motion_bucket_id = int(motion_bucket_id_slider) 65 | vae_processor = VaeImageProcessor() 66 | 67 | device = vae.device 68 | dtype = vae.dtype 69 | 70 | f = validation_data.num_frames 71 | pimg = Image.fromarray(init_img["background"]).convert('RGB') 72 | np_mask = init_img["layers"][0][:,:,3] 73 | np_mask[np_mask!=0] = 255 74 | if np_mask.sum() == 0: 75 | np_mask[:] = 255 76 | if input_video is not None: 77 | frames = read_video(input_video) 78 | frames = [Image.fromarray(f) for f in frames] 79 | pimg = frames[0] 80 | width, height = pimg.size 81 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width)) 82 | block_size=64 83 | height = round(height/scale/block_size)*block_size 84 | width = round(width/scale/block_size)*block_size 85 | f = len(frames) 86 | 87 | latents = [] 88 | for frame in frames: 89 | input_image = vae_processor.preprocess(frame, height, width) 90 | input_image = input_image.to(dtype).to(device) 91 | input_image_latent = vae.encode(input_image).latent_dist.mode() * vae.config.scaling_factor 92 | latents.append(input_image_latent.unsqueeze(1)) 93 | latents = torch.cat(latents, dim=1) 94 | else: 95 | width, height = pimg.size 96 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width)) 97 | block_size=64 98 | height = round(height/scale/block_size)*block_size 99 | width = round(width/scale/block_size)*block_size 100 | input_image = vae_processor.preprocess(pimg, height, width) 101 | input_image = input_image.to(dtype).to(device) 102 | input_image_latent = vae.encode(input_image).latent_dist.mode() * vae.config.scaling_factor 103 | latents = repeat(input_image_latent, 'b c h w->b f c h w', f=f) 104 | 105 | b, f, c, h, w = latents.shape 106 | 107 | mask = T.ToTensor()(np_mask).to(dtype).to(device) 108 | mask = T.Resize([h, w], antialias=False)(mask) 109 | mask = repeat(mask, 'b h w -> b f 1 h w', f=f).detach().clone() 110 | mask[:,0] = 0 111 | freeze = repeat(latents[:,0], 'b c h w -> b f c h w', f=f) 112 | condition_latents = latents * (1-mask) + freeze * mask 113 | condition_latents = condition_latents/vae.config.scaling_factor 114 | 115 | motion_mask = self.pipeline.unet.config.in_channels == 9 116 | decode_chunk_size=validation_data.get("decode_chunk_size", 7) 117 | fps=validation_data.get("fps", 7) 118 | motion_bucket_id=validation_data.get("motion_bucket_id", 127) 119 | if motion_mask: 120 | video_frames = TextStableVideoDiffusionPipeline.__call__( 121 | self.pipeline, 122 | image=pimg, 123 | width=width, 124 | height=height, 125 | 
num_frames=validation_data.num_frames, 126 | num_inference_steps=validation_data.num_inference_steps, 127 | decode_chunk_size=decode_chunk_size, 128 | fps=fps, 129 | motion_bucket_id=motion_bucket_id, 130 | mask=mask, 131 | condition_type="image", 132 | condition_latent=condition_latents 133 | ).frames[0] 134 | else: 135 | video_frames = self.pipeline( 136 | image=pimg, 137 | width=width, 138 | height=height, 139 | num_frames=validation_data.num_frames, 140 | num_inference_steps=validation_data.num_inference_steps, 141 | fps=validation_data.fps, 142 | decode_chunk_size=validation_data.decode_chunk_size, 143 | motion_bucket_id=validation_data.motion_bucket_id, 144 | ).frames[0] 145 | 146 | save_sample_path = os.path.join( 147 | self.output_dir, f"{self.sample_idx}.mp4") 148 | Image.fromarray(np_mask).save(os.path.join( 149 | self.output_dir, f"{self.sample_idx}_label.jpg")) 150 | imageio.mimwrite(save_sample_path, video_frames, fps=7) 151 | self.sample_idx += 1 152 | return save_sample_path 153 | 154 | import cv2 155 | 156 | def get_video_info(video_path): 157 | cap = cv2.VideoCapture(video_path) 158 | if not cap.isOpened(): 159 | return None 160 | 161 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 162 | cap.release() 163 | 164 | return length 165 | 166 | def update_num_frames(input_video, num_frames_textbox): 167 | frame_count = get_video_info(input_video) 168 | return frame_count or 14 169 | 170 | def ui(controller): 171 | with gr.Blocks(css=css) as demo: 172 | 173 | gr.HTML( 174 | "
Animate Anything For SVD
" 175 | ) 176 | with gr.Row(): 177 | gr.Markdown( 178 | "
Project Page  " # noqa 179 | "Paper  " 180 | "Code  " # noqa 181 | ) 182 | 183 | with gr.Row(equal_height=True): 184 | with gr.Column(): 185 | init_img = gr.ImageMask(label='Input Image', brush=gr.Brush(default_size=100)) 186 | generate_button = gr.Button( 187 | value="Generate", variant='primary') 188 | input_video = gr.Video(label="Input video", interactive=True) 189 | 190 | result_video = gr.Video( 191 | label="Generated Animation", interactive=False) 192 | 193 | with gr.Accordion('Advance Options', open=False): 194 | with gr.Row(): 195 | fps_textbox = gr.Number(label="Fps", value=7, minimum=1) 196 | num_frames_textbox = gr.Number(label="Num frames", value=14, minimum=1, maximum=78) 197 | 198 | input_video.upload( 199 | fn=update_num_frames, 200 | inputs=[input_video], 201 | outputs=[num_frames_textbox] 202 | ) 203 | 204 | motion_bucket_id_slider = gr.Slider( 205 | label='motion_bucket_id', 206 | value=127, step=1, minimum=0, maximum=511) 207 | 208 | sample_step_slider = gr.Slider( 209 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1) 210 | 211 | with gr.Row(): 212 | seed_textbox = gr.Textbox(label="Seed", value=-1) 213 | seed_button = gr.Button( 214 | value="\U0001F3B2", elem_classes="toolbutton") 215 | seed_button.click( 216 | fn=lambda x: random.randint(1, 1e8), 217 | outputs=[seed_textbox], 218 | queue=False 219 | ) 220 | 221 | 222 | 223 | generate_button.click( 224 | fn=controller.animate, 225 | inputs=[ 226 | init_img, 227 | input_video, 228 | sample_step_slider, 229 | seed_textbox, 230 | fps_textbox, 231 | num_frames_textbox, 232 | motion_bucket_id_slider 233 | ], 234 | outputs=[result_video] 235 | ) 236 | 237 | return demo 238 | 239 | 240 | if __name__ == "__main__": 241 | parser = ArgumentParser() 242 | parser.add_argument('--config', type=str, default='example/config/base.yaml') 243 | parser.add_argument('--server-name', type=str, default='0.0.0.0') 244 | parser.add_argument('--port', type=int, default=7860) 245 | parser.add_argument('--share', action='store_true') 246 | parser.add_argument('--local-debug', action='store_true') 247 | parser.add_argument('--save-path', default='samples') 248 | 249 | args, unknownargs = parser.parse_known_args() 250 | LOCAL_DEBUG = args.local_debug 251 | args_dict = OmegaConf.load(args.config) 252 | cli_conf = OmegaConf.from_cli() 253 | args_dict = OmegaConf.merge(args_dict, cli_conf) 254 | controller = AnimateController(args_dict.pretrained_model_path, args_dict.validation_data, 255 | args_dict.output_dir, args_dict.motion_mask, args_dict.motion_strength) 256 | demo = ui(controller) 257 | demo.queue(max_size=10) 258 | demo.launch(server_name=args.server_name, 259 | server_port=args.port, max_threads=40, 260 | ) 261 | -------------------------------------------------------------------------------- /colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dailingx/animate-anything/blob/main/colab.ipynb)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "outputs": [], 19 | "source": [ 20 | "%cd /content\n", 21 | "!git clone https://github.com/alibaba/animate-anything /content/animate-anything\n", 22 | "%cd /content/animate-anything\n", 23 | "\n", 24 | "!pip install -r 
requirements.txt\n", 25 | "\n", 26 | "!apt -y install -qq aria2\n", 27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_512_v1.02.tar -d output/latent\n", 28 | "!tar -xf output/latent/animate_anything_512_v1.02.tar -C output/latent/\n", 29 | "\n", 30 | "!python app.py --config output/latent/animate_anything_512_v1.02/config.yaml --share" 31 | ], 32 | "metadata": { 33 | "collapsed": false, 34 | "pycharm": { 35 | "name": "#%%\n" 36 | } 37 | } 38 | } 39 | ], 40 | "metadata": { 41 | "accelerator": "GPU", 42 | "colab": { 43 | "gpuType": "T4", 44 | "provenance": [] 45 | }, 46 | "kernelspec": { 47 | "display_name": "Python 3", 48 | "language": "python", 49 | "name": "python3" 50 | }, 51 | "language_info": { 52 | "codemirror_mode": { 53 | "name": "ipython", 54 | "version": 2 55 | }, 56 | "file_extension": ".py", 57 | "mimetype": "text/x-python", 58 | "name": "python", 59 | "nbconvert_exporter": "python", 60 | "pygments_lexer": "ipython2", 61 | "version": "2.7.6" 62 | } 63 | }, 64 | "nbformat": 4, 65 | "nbformat_minor": 0 66 | } 67 | -------------------------------------------------------------------------------- /compress_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip 3 | Author: ArrowLuo 4 | """ 5 | import os 6 | import argparse 7 | import ffmpeg 8 | import subprocess 9 | import time 10 | import multiprocessing 11 | from multiprocessing import Pool 12 | import shutil 13 | import json 14 | try: 15 | from psutil import cpu_count 16 | except: 17 | from multiprocessing import cpu_count 18 | # multiprocessing.freeze_support() 19 | 20 | def compress(paras): 21 | input_video_path, output_video_path = paras 22 | try: 23 | command = ['ffmpeg', 24 | '-y', # (optional) overwrite output file if it exists 25 | '-i', input_video_path, 26 | '-filter:v', 27 | 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,512)\':\'if(gt(a,1),512,trunc(ow*a/2)*2)\'', # scale to 256 28 | '-map', '0:v', 29 | #'-r', '3', # frames per second 30 | output_video_path, 31 | ] 32 | ffmpeg = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 33 | out, err = ffmpeg.communicate() 34 | retcode = ffmpeg.poll() 35 | # print something above for debug 36 | except Exception as e: 37 | raise e 38 | 39 | def prepare_input_output_pairs(input_root, output_root): 40 | input_video_path_list = [] 41 | output_video_path_list = [] 42 | for root, dirs, files in os.walk(input_root): 43 | for file_name in files: 44 | input_video_path = os.path.join(root, file_name) 45 | output_video_path = os.path.join(output_root, file_name) 46 | output_video_path = os.path.splitext(output_video_path)[0] + ".mp4" 47 | if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0: 48 | pass 49 | else: 50 | input_video_path_list.append(input_video_path) 51 | output_video_path_list.append(output_video_path) 52 | return input_video_path_list, output_video_path_list 53 | 54 | def msvd(): 55 | captions = pickle.load(open('raw-captions.pkl','rb')) 56 | outdir = "/data/datasets/msvd/videos_mp4" 57 | for key in captions: 58 | outpath = os.path.join(outdir, key+".txt") 59 | with open(outpath, 'w') as f: 60 | for line in captions[key]: 61 | f.write(" ".join(line)+"\n") 62 | 63 | def webvid(): 64 | 65 | df = pd.read_csv('/webvid/results_2M_train_1/0.csv') 66 | df['rel_fn'] = df.apply(lambda x: 
os.path.join(str(x['page_dir']), str(x['videoid'])), axis=1) 67 | 68 | df['rel_fn'] = df['rel_fn'] + '.mp4' 69 | # remove nan 70 | df.dropna(subset=['page_dir'], inplace=True) 71 | 72 | playlists_to_dl = np.sort(df['page_dir'].unique()) 73 | 74 | vjson = [] 75 | video_dir = '/webvid/webvid/data/videos' 76 | for page_dir in playlists_to_dl: 77 | vid_dir_t = os.path.join(video_dir, page_dir) 78 | pdf = df[df['page_dir'] == page_dir] 79 | if len(pdf) > 0: 80 | for idx, row in pdf.iterrows(): 81 | video_fp = os.path.join(vid_dir_t, str(row['videoid']) + '.mp4') 82 | if os.path.isfile(video_fp): 83 | caption = row['name'] 84 | video_path = os.path.join(page_dir, str(row['videoid'])+'.mp4') 85 | vjson.append({'caption':caption,'video':video_path}) 86 | with open('/webvid/webvid/data/2M.json', 'w') as f: 87 | json.dump(vjson, f) 88 | 89 | def webvid20k(): 90 | j = json.load(open('/webvid/webvid/data/2M.json')) 91 | idir = '/webvid/webvid/data/videos' 92 | 93 | v2c = [] 94 | for item in j: 95 | caption = item['caption'] 96 | video = item['video'] 97 | if os.path.exists(os.path.join(idir, video)): 98 | v2c.append(item) 99 | print("video numbers", len(v2c)) 100 | with open('/webvid/webvid/data/40K.json', 'w') as f: 101 | json.dump(v2c, f) 102 | 103 | 104 | if __name__ == "__main__": 105 | parser = argparse.ArgumentParser(description='Compress video for speed-up') 106 | parser.add_argument('--input_root', type=str, help='input root') 107 | parser.add_argument('--output_root', type=str, help='output root') 108 | args = parser.parse_args() 109 | 110 | input_root = args.input_root 111 | output_root = args.output_root 112 | 113 | assert input_root != output_root 114 | 115 | if not os.path.exists(output_root): 116 | os.makedirs(output_root, exist_ok=True) 117 | 118 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root) 119 | 120 | print("Total video need to process: {}".format(len(input_video_path_list))) 121 | num_works = cpu_count() 122 | print("Begin with {}-core logical processor.".format(num_works)) 123 | 124 | pool = Pool(num_works) 125 | data_dict_list = pool.map(compress, 126 | [(input_video_path, output_video_path) for 127 | input_video_path, output_video_path in 128 | zip(input_video_path_list, output_video_path_list)]) 129 | pool.close() 130 | pool.join() 131 | 132 | print("Compress finished, wait for checking files...") 133 | for input_video_path, output_video_path in zip(input_video_path_list, output_video_path_list): 134 | if os.path.exists(input_video_path): 135 | if os.path.exists(output_video_path) is False or os.path.getsize(output_video_path) < 1.: 136 | print("convert fail: {}".format(output_video_path)) -------------------------------------------------------------------------------- /docs/4_sr.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/4_sr.mp4 -------------------------------------------------------------------------------- /docs/barbie2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/barbie2.mp4 -------------------------------------------------------------------------------- /docs/fish.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish.gif -------------------------------------------------------------------------------- /docs/fish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish.jpg -------------------------------------------------------------------------------- /docs/fish_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish_mask.png -------------------------------------------------------------------------------- /docs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/framework.png -------------------------------------------------------------------------------- /docs/girl5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/girl5.mp4 -------------------------------------------------------------------------------- /docs/labelme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/labelme.png -------------------------------------------------------------------------------- /docs/pig0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/pig0.mp4 -------------------------------------------------------------------------------- /docs/qingming2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/qingming2.gif -------------------------------------------------------------------------------- /docs/qingming2_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/qingming2_label.jpg -------------------------------------------------------------------------------- /docs/sample_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_1.gif -------------------------------------------------------------------------------- /docs/sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_1.png -------------------------------------------------------------------------------- /docs/sample_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_2.gif -------------------------------------------------------------------------------- /docs/sample_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_2.png -------------------------------------------------------------------------------- /docs/sample_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_3.gif -------------------------------------------------------------------------------- /docs/sample_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_3.png -------------------------------------------------------------------------------- /example/barbie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/barbie.jpg -------------------------------------------------------------------------------- /example/barbie2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/barbie2.jpg -------------------------------------------------------------------------------- /example/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | offload_optimizer_device: cpu 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | gpu_ids: 0,1,2,3 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: fp16 15 | num_machines: 1 16 | num_processes: 4 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /example/example_padded_rgba_pngs/apple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/apple.png -------------------------------------------------------------------------------- /example/example_padded_rgba_pngs/put rgba images here for train_transparent_i2v_stage2.py.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/put rgba images here for train_transparent_i2v_stage2.py.txt -------------------------------------------------------------------------------- /example/example_padded_rgba_pngs/ziyan0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/ziyan0.png -------------------------------------------------------------------------------- /example/example_rgba_video_results/animated rgba results for our transparent unet.txt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/animated rgba results for our transparent unet.txt -------------------------------------------------------------------------------- /example/example_rgba_video_results/apple/decoded_alpha.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/apple/decoded_alpha.webp -------------------------------------------------------------------------------- /example/example_rgba_video_results/apple/decoded_rgba.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/apple/decoded_rgba.webp -------------------------------------------------------------------------------- /example/example_rgba_video_results/ziyan0/decoded_alpha.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/ziyan0/decoded_alpha.webp -------------------------------------------------------------------------------- /example/example_rgba_video_results/ziyan0/decoded_rgba.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/ziyan0/decoded_rgba.webp -------------------------------------------------------------------------------- /example/fish1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/fish1.jpg -------------------------------------------------------------------------------- /example/fish1_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/fish1_label.jpg -------------------------------------------------------------------------------- /example/girl5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/girl5.jpg -------------------------------------------------------------------------------- /example/hulu2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/hulu2.jpg -------------------------------------------------------------------------------- /example/hulu3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/hulu3.jpg -------------------------------------------------------------------------------- /example/layerdiffuse_stage2_384.yaml: -------------------------------------------------------------------------------- 1 | # Pretrained diffusers model path. 
2 | transparent_unet_pretrained_model_path: "./output/latent/transparent_unet" 3 | transparent_VAE_pretrained_model_path: "./output/latent/transparent_VAE" 4 | 5 | motion_mask: True 6 | motion_strength: True 7 | in_channels: 5 # 5 or 9 8 | 9 | # The folder where your training outputs will be placed. 10 | output_dir: "output/stage_2_eval" 11 | 12 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise 13 | # If this is enabled, rescale_schedule will be disabled. 14 | offset_noise_strength: 0.1 15 | use_offset_noise: False 16 | 17 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf 18 | # If this is enabled, offset noise will be disabled. 19 | rescale_schedule: True 20 | 21 | # When True, this extends all items in all enabled datasets to the highest length. 22 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 23 | extend_dataset: False 24 | 25 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 26 | # The latents will be saved under your training folder, and loaded automatically for training. 27 | # This both saves memory and speeds up training and takes very little disk space. 28 | cache_latents: False 29 | 30 | # If you have cached latents set to `True` and have a directory of cached latents, 31 | # you can skip the caching process and load previously saved ones. 32 | cached_latent_dir: null #/path/to/cached_latents 33 | 34 | # Train the text encoder for the model. LoRA Training overrides this setting. 35 | train_text_encoder: False 36 | 37 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 38 | # This is the first, original implementation of LoRA by cloneofsimo. 39 | # Use this version if you want to maintain compatibility to the original version. 40 | 41 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 42 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 43 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 44 | # This means that there isn't an extra function when doing low ranking adaption. 45 | # It solely saves the weight differential between the initialized weights and updates. 46 | 47 | # "cloneofsimo" or "stable_lora" 48 | lora_version: "cloneofsimo" 49 | 50 | # Use LoRA for the UNET model. 51 | use_unet_lora: False 52 | 53 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 54 | use_text_lora: False 55 | 56 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 57 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 58 | lora_unet_dropout: 0.1 59 | 60 | lora_text_dropout: 0.1 61 | 62 | # https://github.com/kabachuha/sd-webui-text2video 63 | # This saves a LoRA that is compatible with the text2video webui extension. 64 | # It only works when the lora version is 'stable_lora'. 65 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 66 | save_lora_for_webui: True 67 | 68 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 
69 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 70 | # when this version is set to False 71 | only_lora_for_webui: False 72 | 73 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 74 | # The only time you want this off is if you're doing full LoRA training. 75 | save_pretrained_model: True 76 | 77 | # The modules to use for LoRA. Different from 'trainable_modules'. 78 | unet_lora_modules: 79 | - "UNet3DConditionModel" 80 | #- "ResnetBlock2D" 81 | #- "TransformerTemporalModel" 82 | #- "Transformer2DModel" 83 | #- "CrossAttention" 84 | #- "Attention" 85 | #- "GEGLU" 86 | #- "TemporalConvLayer" 87 | 88 | # The modules to use for LoRA. Different from `trainable_text_modules`. 89 | text_encoder_lora_modules: 90 | - "CLIPEncoderLayer" 91 | #- "CLIPAttention" 92 | 93 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 94 | # VRAM increases with higher rank, lower when decreased. 95 | lora_rank: 16 96 | 97 | # You can train multiple datasets at once. They will be joined together for training. 98 | # Simply remove the line you don't need, or keep them all for mixed training. 99 | 100 | # 'image': A folder of images and captions (.txt) 101 | # 'folder': A folder a videos and captions (.txt) 102 | # 'json': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 103 | # 'video_json': a video foler and a json caption file 104 | # 'single_video': A single video file.mp4 and text prompt 105 | dataset_types: 106 | #- 'single_video' 107 | #- 'folder' 108 | # - 'image' 109 | - 'video_blip' 110 | # - 'video_json' 111 | 112 | # Training data parameters 113 | train_data: 114 | width: 384 115 | height: 384 116 | use_bucketing: False 117 | return_mask: True 118 | return_motion: True 119 | sample_start_idx: 0 120 | fps: 8 121 | n_sample_frames: 8 122 | 123 | json_path: '' 124 | 125 | 126 | # Validation data parameters. 127 | validation_data: 128 | 129 | # A custom prompt that is different from your training dataset. 130 | prompt: "" 131 | 132 | prompt_image: "" 133 | 134 | # Whether or not to sample preview during training (Requires more VRAM). 135 | sample_preview: True 136 | 137 | # The number of frames to sample during validation. 138 | num_frames: 8 139 | 140 | # Height and width of validation sample. 141 | width: 384 142 | height: 384 143 | 144 | # Number of inference steps when generating the video. 145 | num_inference_steps: 25 146 | 147 | # CFG scale 148 | guidance_scale: 9 149 | 150 | # Learning rate for AdamW 151 | learning_rate: 3.0e-05 152 | lr_scheduler: "cosine" 153 | lr_warmup_steps: 20 154 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 155 | adam_weight_decay: 0 156 | 157 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 158 | extra_unet_params: null 159 | #learning_rate: 1e-5 160 | #adam_weight_decay: 1e-4 161 | 162 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 163 | extra_text_encoder_params: null 164 | #learning_rate: 1e-4 165 | #adam_weight_decay: 0.2 166 | 167 | # How many batches to train. Not to be confused with video frames. 168 | train_batch_size: 8 169 | image_batch_size: 48 170 | gradient_accumulation_steps: 4 171 | # Maximum number of train steps. Model is saved after training. 172 | max_train_steps: 2000 173 | 174 | # Saves a model every nth step. 
175 | checkpointing_steps: 200 176 | 177 | # How many steps to do for validation if sample_preview is enabled. 178 | validation_steps: 50 179 | 180 | # Which modules we want to unfreeze for the UNET. Advanced usage. 181 | trainable_modules: 182 | - "all" 183 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions" 184 | # This is for self attetion. Activates for spatial and temporal dimensions if n_sample_frames > 1 185 | - "attn1" 186 | - ".attentions" 187 | 188 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1 189 | - "attn2" 190 | - "conv_in" 191 | 192 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1 193 | - "temp_conv" 194 | - "motion" 195 | 196 | 197 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 198 | trainable_text_modules: null 199 | 200 | # Seed for validation. 201 | seed: null 202 | 203 | # Whether or not we want to use mixed precision with accelerate 204 | mixed_precision: "fp16" 205 | 206 | # This seems to be incompatible at the moment. 207 | use_8bit_adam: False 208 | 209 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 210 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 211 | gradient_checkpointing: True 212 | text_encoder_gradient_checkpointing: False 213 | 214 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 215 | enable_xformers_memory_efficient_attention: False 216 | 217 | # Use scaled dot product attention (Only available with >= Torch 2.0) 218 | enable_torch_2_attn: True 219 | -------------------------------------------------------------------------------- /example/pig0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/pig0.jpg -------------------------------------------------------------------------------- /example/pig0_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/pig0_label.jpg -------------------------------------------------------------------------------- /example/qingming2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/qingming2.jpg -------------------------------------------------------------------------------- /example/qingming2_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/qingming2_label.jpg -------------------------------------------------------------------------------- /example/train_mask_motion.yaml: -------------------------------------------------------------------------------- 1 | # Pretrained diffusers model path. 
2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main 3 | pretrained_model_path: "output/latent/animate_anything_512_v1.02" 4 | #pretrained_model_path: "output/latent/train_4fps" 5 | #pretrained_model_path: "/data/llm/zeroscope_v2_576w" 6 | 7 | motion_mask: True 8 | motion_strength: True 9 | 10 | # The folder where your training outputs will be placed. 11 | output_dir: "./output/latent" 12 | 13 | # You can train multiple datasets at once. They will be joined together for training. 14 | # Simply remove the line you don't need, or keep them all for mixed training. 15 | 16 | # 'image': A folder of images and captions (.txt) 17 | # 'folder': A folder a videos and captions (.txt) 18 | # 'json': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 19 | # 'video_json': a video foler and a json caption file 20 | # 'single_video': A single video file.mp4 and text prompt 21 | dataset_types: 22 | #- 'single_video' 23 | #- 'folder' 24 | #- 'image' 25 | - 'video_blip' 26 | #- 'video_json' 27 | 28 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise 29 | # If this is enabled, rescale_schedule will be disabled. 30 | offset_noise_strength: 0.1 31 | use_offset_noise: False 32 | 33 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf 34 | # If this is enabled, offset noise will be disabled. 35 | rescale_schedule: True 36 | 37 | # When True, this extends all items in all enabled datasets to the highest length. 38 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 39 | extend_dataset: False 40 | 41 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 42 | # The latents will be saved under your training folder, and loaded automatically for training. 43 | # This both saves memory and speeds up training and takes very little disk space. 44 | cache_latents: False 45 | 46 | # If you have cached latents set to `True` and have a directory of cached latents, 47 | # you can skip the caching process and load previously saved ones. 48 | cached_latent_dir: null #/path/to/cached_latents 49 | 50 | # Train the text encoder for the model. LoRA Training overrides this setting. 51 | train_text_encoder: False 52 | 53 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 54 | # This is the first, original implementation of LoRA by cloneofsimo. 55 | # Use this version if you want to maintain compatibility to the original version. 56 | 57 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 58 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 59 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 60 | # This means that there isn't an extra function when doing low ranking adaption. 61 | # It solely saves the weight differential between the initialized weights and updates. 62 | 63 | # "cloneofsimo" or "stable_lora" 64 | lora_version: "cloneofsimo" 65 | 66 | # Use LoRA for the UNET model. 67 | use_unet_lora: False 68 | 69 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 70 | use_text_lora: False 71 | 72 | # LoRA Dropout. 
This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 73 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 74 | lora_unet_dropout: 0.1 75 | 76 | lora_text_dropout: 0.1 77 | 78 | # https://github.com/kabachuha/sd-webui-text2video 79 | # This saves a LoRA that is compatible with the text2video webui extension. 80 | # It only works when the lora version is 'stable_lora'. 81 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 82 | save_lora_for_webui: True 83 | 84 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 85 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 86 | # when this version is set to False 87 | only_lora_for_webui: False 88 | 89 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 90 | # The only time you want this off is if you're doing full LoRA training. 91 | save_pretrained_model: True 92 | 93 | # The modules to use for LoRA. Different from 'trainable_modules'. 94 | unet_lora_modules: 95 | - "UNet3DConditionModel" 96 | #- "ResnetBlock2D" 97 | #- "TransformerTemporalModel" 98 | #- "Transformer2DModel" 99 | #- "CrossAttention" 100 | #- "Attention" 101 | #- "GEGLU" 102 | #- "TemporalConvLayer" 103 | 104 | # The modules to use for LoRA. Different from `trainable_text_modules`. 105 | text_encoder_lora_modules: 106 | - "CLIPEncoderLayer" 107 | #- "CLIPAttention" 108 | 109 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 110 | # VRAM increases with higher rank, lower when decreased. 111 | lora_rank: 16 112 | 113 | # Training data parameters 114 | train_data: 115 | width: 512 116 | height: 512 117 | use_bucketing: False 118 | return_mask: True 119 | return_motion: True 120 | sample_start_idx: 1 121 | fps: 8 122 | n_sample_frames: 16 123 | json_path: '/webvid/animation0.json' 124 | 125 | # Validation data parameters. 126 | validation_data: 127 | 128 | # A custom prompt that is different from your training dataset. 129 | prompt: "a girl moves hands" 130 | 131 | prompt_image: "output/example/girl4.jpg" 132 | 133 | # Whether or not to sample preview during training (Requires more VRAM). 134 | sample_preview: True 135 | 136 | # The number of frames to sample during validation. 137 | num_frames: 16 138 | 139 | # Height and width of validation sample. 140 | width: 512 141 | height: 512 142 | 143 | # Number of inference steps when generating the video. 144 | num_inference_steps: 25 145 | 146 | # CFG scale 147 | guidance_scale: 9 148 | 149 | # Learning rate for AdamW 150 | learning_rate: 5.0e-06 151 | 152 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 153 | adam_weight_decay: 0 154 | 155 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 156 | extra_unet_params: null 157 | #learning_rate: 1e-5 158 | #adam_weight_decay: 1e-4 159 | 160 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 161 | extra_text_encoder_params: null 162 | #learning_rate: 1e-4 163 | #adam_weight_decay: 0.2 164 | 165 | # How many batches to train. Not to be confused with video frames. 166 | train_batch_size: 8 167 | # Maximum number of train steps. Model is saved after training. 168 | max_train_steps: 5000 169 | 170 | # Saves a model every nth step. 
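[Note] Configs like this one are passed to the training scripts via `--config` (see `run.sh`). A minimal sketch of reading such a file with OmegaConf, which `requirements.txt` pins, is shown below; the actual argument handling inside `train.py` may differ.

```python
# Minimal sketch: load a training config and apply dotlist-style overrides.
from omegaconf import OmegaConf

config = OmegaConf.load("example/train_mask_motion.yaml")
overrides = OmegaConf.from_dotlist(["train_data.n_sample_frames=8", "max_train_steps=100"])
config = OmegaConf.merge(config, overrides)

print(config.pretrained_model_path, config.train_data.width, config.max_train_steps)
```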
171 | checkpointing_steps: 1000 172 | 173 | # How many steps to do for validation if sample_preview is enabled. 174 | validation_steps: 200 175 | 176 | # Which modules we want to unfreeze for the UNET. Advanced usage. 177 | trainable_modules: 178 | #- "all" 179 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions" 180 | # This is for self attetion. Activates for spatial and temporal dimensions if n_sample_frames > 1 181 | - "attn1" 182 | - ".attentions" 183 | 184 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1 185 | - "attn2" 186 | - "conv_in" 187 | 188 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1 189 | - "temp_conv" 190 | - "motion" 191 | 192 | 193 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 194 | trainable_text_modules: null 195 | 196 | # Seed for validation. 197 | seed: null 198 | 199 | # Whether or not we want to use mixed precision with accelerate 200 | mixed_precision: "fp16" 201 | 202 | # This seems to be incompatible at the moment. 203 | use_8bit_adam: False 204 | 205 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 206 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 207 | gradient_checkpointing: True 208 | text_encoder_gradient_checkpointing: False 209 | 210 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 211 | enable_xformers_memory_efficient_attention: False 212 | 213 | # Use scaled dot product attention (Only available with >= Torch 2.0) 214 | enable_torch_2_attn: True 215 | -------------------------------------------------------------------------------- /example/train_mask_motion_lora.yaml: -------------------------------------------------------------------------------- 1 | # running scripts: 2 | # accelerate launch --config_file example/deepspeed.yaml train_lora.py --config example/train_mask_motion_lora.yaml 3 | # python train_lora.py --config example/train_mask_motion_lora.yaml --eval 4 | 5 | # Pretrained diffusers model path. 6 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main 7 | pretrained_model_path: "output/latent/animate_anything_512_v1.02" 8 | 9 | # pretrained lora path 10 | # lora_path is only valid during eval (--eval). 11 | # lora module is saved to {output_dir}/train_{datetime}/{checkpoint}/lora by default during training 12 | lora_path: "/path/to/your_lora_module" 13 | 14 | motion_mask: True 15 | motion_strength: True 16 | 17 | # The folder where your training outputs will be placed. 18 | output_dir: "./output/latent" 19 | 20 | # You can train multiple datasets at once. They will be joined together for training. 21 | # Simply remove the line you don't need, or keep them all for mixed training. 22 | 23 | # 'image': A folder of images and captions (.txt) 24 | # 'folder': A folder a videos and captions (.txt) 25 | # 'json': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 26 | # 'video_json': a video foler and a json caption file 27 | # 'single_video': A single video file.mp4 and text prompt 28 | dataset_types: 29 | #- 'single_video' 30 | #- 'folder' 31 | #- 'image' 32 | - 'video_blip' 33 | #- 'video_json' 34 | 35 | # Adds offset noise to training. 
See https://www.crosslabs.org/blog/diffusion-with-offset-noise 36 | # If this is enabled, rescale_schedule will be disabled. 37 | offset_noise_strength: 0.1 38 | use_offset_noise: False 39 | 40 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf 41 | # If this is enabled, offset noise will be disabled. 42 | rescale_schedule: True 43 | 44 | # When True, this extends all items in all enabled datasets to the highest length. 45 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 46 | extend_dataset: False 47 | 48 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 49 | # The latents will be saved under your training folder, and loaded automatically for training. 50 | # This both saves memory and speeds up training and takes very little disk space. 51 | cache_latents: False 52 | 53 | # If you have cached latents set to `True` and have a directory of cached latents, 54 | # you can skip the caching process and load previously saved ones. 55 | cached_latent_dir: null #/path/to/cached_latents 56 | 57 | # Train the text encoder for the model. LoRA Training overrides this setting. 58 | train_text_encoder: False 59 | 60 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 61 | # This is the first, original implementation of LoRA by cloneofsimo. 62 | # Use this version if you want to maintain compatibility to the original version. 63 | 64 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 65 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 66 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 67 | # This means that there isn't an extra function when doing low ranking adaption. 68 | # It solely saves the weight differential between the initialized weights and updates. 69 | 70 | # "cloneofsimo" or "stable_lora" 71 | lora_version: "cloneofsimo" 72 | 73 | # Use LoRA for the UNET model. 74 | use_unet_lora: True 75 | 76 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 77 | use_text_lora: False 78 | 79 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 80 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 81 | lora_unet_dropout: 0.1 82 | 83 | lora_text_dropout: 0.1 84 | 85 | # https://github.com/kabachuha/sd-webui-text2video 86 | # This saves a LoRA that is compatible with the text2video webui extension. 87 | # It only works when the lora version is 'stable_lora'. 88 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 89 | save_lora_for_webui: True 90 | 91 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 92 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 93 | # when this version is set to False 94 | only_lora_for_webui: False 95 | 96 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 97 | # The only time you want this off is if you're doing full LoRA training. 98 | save_pretrained_model: True 99 | 100 | # The modules to use for LoRA. Different from 'trainable_modules'. 
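[Note] The cloneofsimo-style LoRA selected above wraps existing Linear layers with a low-rank update whose rank and dropout correspond to the `lora_rank` and `lora_unet_dropout` keys in this file. A generic sketch of that idea (not the repo's exact classes) follows:

```python
# Generic low-rank adapter sketch: y = W x + scale * B(A(dropout(x))), rank r.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 16, alpha: int = 16, dropout: float = 0.1):
        super().__init__()
        self.base = base                      # frozen pretrained layer
        self.base.requires_grad_(False)
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)        # start as an identity-preserving update
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(self.dropout(x)))
```

Only the small `down`/`up` matrices are trained and saved, which is why LoRA checkpoints stay tiny compared with the full UNET.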
101 | unet_lora_modules: 102 | - "UNet3DConditionModel" 103 | #- "ResnetBlock2D" 104 | #- "TransformerTemporalModel" 105 | #- "Transformer2DModel" 106 | #- "CrossAttention" 107 | #- "Attention" 108 | #- "GEGLU" 109 | #- "TemporalConvLayer" 110 | 111 | # The modules to use for LoRA. Different from `trainable_text_modules`. 112 | text_encoder_lora_modules: 113 | - "CLIPEncoderLayer" 114 | #- "CLIPAttention" 115 | 116 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 117 | # VRAM increases with higher rank, lower when decreased. 118 | lora_rank: 16 119 | 120 | # Training data parameters 121 | train_data: 122 | width: 512 123 | height: 512 124 | use_bucketing: False 125 | return_mask: True 126 | return_motion: True 127 | sample_start_idx: 1 128 | fps: 8 129 | n_sample_frames: 16 130 | json_path: '/webvid/animation0.json' 131 | 132 | # Validation data parameters. 133 | validation_data: 134 | 135 | # A custom prompt that is different from your training dataset. 136 | prompt: "a girl smiling" 137 | 138 | prompt_image: "example/barbie.jpg" 139 | 140 | # Whether or not to sample preview during training (Requires more VRAM). 141 | sample_preview: True 142 | 143 | # The number of frames to sample during validation. 144 | num_frames: 16 145 | 146 | # Height and width of validation sample. 147 | width: 512 148 | height: 512 149 | 150 | # Number of inference steps when generating the video. 151 | num_inference_steps: 25 152 | 153 | # CFG scale 154 | guidance_scale: 9 155 | 156 | # Learning rate for AdamW 157 | learning_rate: 5.0e-06 158 | 159 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 160 | adam_weight_decay: 0 161 | 162 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 163 | extra_unet_params: null 164 | #learning_rate: 1e-5 165 | #adam_weight_decay: 1e-4 166 | 167 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 168 | extra_text_encoder_params: null 169 | #learning_rate: 1e-4 170 | #adam_weight_decay: 0.2 171 | 172 | # How many batches to train. Not to be confused with video frames. 173 | train_batch_size: 4 174 | # Maximum number of train steps. Model is saved after training. 175 | max_train_steps: 1000 176 | 177 | # Saves a model every nth step. 178 | checkpointing_steps: 100 179 | 180 | # How many steps to do for validation if sample_preview is enabled. 181 | validation_steps: 100 182 | 183 | # Which modules we want to unfreeze for the UNET. Advanced usage. 184 | # trainable_modules: 185 | # - "None" 186 | 187 | 188 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 189 | trainable_text_modules: null 190 | 191 | # Seed for validation. 192 | seed: null 193 | 194 | # Whether or not we want to use mixed precision with accelerate 195 | mixed_precision: "fp16" 196 | 197 | # This seems to be incompatible at the moment. 198 | use_8bit_adam: False 199 | 200 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 201 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 
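[Note] The `guidance_scale` under `validation_data` is the usual classifier-free guidance weight. For reference, the conventional way such a scale is applied at each denoising step is sketched below; the repo's own pipelines handle this internally.

```python
# Conventional classifier-free guidance step (illustrative only).
import torch

def apply_cfg(noise_uncond: torch.Tensor, noise_cond: torch.Tensor,
              guidance_scale: float = 9.0) -> torch.Tensor:
    # guidance_scale = 1.0 disables guidance; larger values follow the prompt more strongly.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```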
202 | gradient_checkpointing: True 203 | text_encoder_gradient_checkpointing: False 204 | 205 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 206 | enable_xformers_memory_efficient_attention: False 207 | 208 | # Use scaled dot product attention (Only available with >= Torch 2.0) 209 | enable_torch_2_attn: True 210 | -------------------------------------------------------------------------------- /example/train_svd.yaml: -------------------------------------------------------------------------------- 1 | # Pretrained diffusers model path. 2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main 3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid" 4 | 5 | 6 | motion_mask: False 7 | motion_strength: False 8 | 9 | # The folder where your training outputs will be placed. 10 | output_dir: "./output/svd" 11 | 12 | # You can train multiple datasets at once. They will be joined together for training. 13 | # Simply remove the line you don't need, or keep them all for mixed training. 14 | 15 | # 'image': A folder of images and captions (.txt) 16 | # 'folder': A folder a videos and captions (.txt) 17 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 18 | # 'video_json': a video foler and a json caption file 19 | # 'single_video': A single video file.mp4 and text prompt 20 | dataset_types: 21 | #- 'single_video' 22 | #- 'folder' 23 | #- 'image' 24 | - 'video_blip' 25 | #- 'video_json' 26 | 27 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise 28 | # If this is enabled, rescale_schedule will be disabled. 29 | offset_noise_strength: 0.1 30 | use_offset_noise: False 31 | 32 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf 33 | # If this is enabled, offset noise will be disabled. 34 | rescale_schedule: False 35 | 36 | # When True, this extends all items in all enabled datasets to the highest length. 37 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 38 | extend_dataset: False 39 | 40 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 41 | # The latents will be saved under your training folder, and loaded automatically for training. 42 | # This both saves memory and speeds up training and takes very little disk space. 43 | cache_latents: False 44 | 45 | # If you have cached latents set to `True` and have a directory of cached latents, 46 | # you can skip the caching process and load previously saved ones. 47 | cached_latent_dir: null #/path/to/cached_latents 48 | 49 | # Train the text encoder for the model. LoRA Training overrides this setting. 50 | train_text_encoder: False 51 | 52 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 53 | # This is the first, original implementation of LoRA by cloneofsimo. 54 | # Use this version if you want to maintain compatibility to the original version. 55 | 56 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 57 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 58 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 
59 | # This means that there isn't an extra function when doing low ranking adaption. 60 | # It solely saves the weight differential between the initialized weights and updates. 61 | 62 | # "cloneofsimo" or "stable_lora" 63 | lora_version: "cloneofsimo" 64 | 65 | # Use LoRA for the UNET model. 66 | use_unet_lora: False 67 | 68 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 69 | use_text_lora: False 70 | 71 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 72 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 73 | lora_unet_dropout: 0.1 74 | 75 | lora_text_dropout: 0.1 76 | 77 | # https://github.com/kabachuha/sd-webui-text2video 78 | # This saves a LoRA that is compatible with the text2video webui extension. 79 | # It only works when the lora version is 'stable_lora'. 80 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 81 | save_lora_for_webui: True 82 | 83 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 84 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 85 | # when this version is set to False 86 | only_lora_for_webui: False 87 | 88 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 89 | # The only time you want this off is if you're doing full LoRA training. 90 | save_pretrained_model: True 91 | 92 | # The modules to use for LoRA. Different from 'trainable_modules'. 93 | unet_lora_modules: 94 | - "UNet3DConditionModel" 95 | #- "ResnetBlock2D" 96 | #- "TransformerTemporalModel" 97 | #- "Transformer2DModel" 98 | #- "CrossAttention" 99 | #- "Attention" 100 | #- "GEGLU" 101 | #- "TemporalConvLayer" 102 | 103 | # The modules to use for LoRA. Different from `trainable_text_modules`. 104 | text_encoder_lora_modules: 105 | - "CLIPEncoderLayer" 106 | #- "CLIPAttention" 107 | 108 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 109 | # VRAM increases with higher rank, lower when decreased. 110 | lora_rank: 16 111 | 112 | # Training data parameters 113 | train_data: 114 | 115 | # The width and height in which you want your training data to be resized to. 116 | width: 512 117 | height: 512 118 | 119 | # This will find the closest aspect ratio to your input width and height. 120 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256 121 | use_bucketing: False 122 | return_mask: True 123 | return_motion: True 124 | 125 | # The start frame index where your videos should start (Leave this at one for json and folder based training). 126 | sample_start_idx: 1 127 | 128 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset. 129 | # high fps, lower frame step, move slowly 130 | fps: 7 131 | 132 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...). 133 | frame_step: 1 134 | 135 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size). 
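[Note] The sampler keys above (`sample_start_idx`, `fps`, `frame_step`, `n_sample_frames`) decide which frames are pulled from each clip. An illustrative sketch of that indexing is given here; the real logic lives in `utils/dataset.py` and may handle edge cases differently.

```python
# Illustrative frame indexing for the sampler settings above.
def sample_frame_indices(total_frames: int, sample_start_idx: int = 1,
                         frame_step: int = 1, n_sample_frames: int = 7):
    indices = [sample_start_idx + i * frame_step for i in range(n_sample_frames)]
    # Clamp to the clip length so short videos still yield n_sample_frames indices.
    return [min(i, total_frames - 1) for i in indices]

print(sample_frame_indices(total_frames=100, frame_step=2))  # [1, 3, 5, 7, 9, 11, 13]
```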
136 | n_sample_frames: 7 137 | 138 | # 'single_video' 139 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4" 140 | 141 | # The prompt when using a a single video file 142 | single_video_prompt: "a bird" 143 | 144 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'. 145 | fallback_prompt: '' 146 | 147 | # 'folder' 148 | #path: "/data2/webvid/data/videos/004151_004200" 149 | path: "/data/datasets/msvd/videos_mp4" 150 | 151 | # 'json' 152 | json_path: '/webvid/animation1.json' 153 | 154 | # 'image' 155 | image_dir: '/vlp/datasets/images/coco' 156 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json' 157 | 158 | video_dir: '/webvid/webvid/data/videos' 159 | video_json: '/webvid/webvid/data/40K.json' 160 | # The prompt for all image files. Leave blank to use caption files (.txt) 161 | single_img_prompt: "" 162 | 163 | 164 | # Validation data parameters. 165 | validation_data: 166 | 167 | # A custom prompt that is different from your training dataset. 168 | prompt: "a girl moves body" 169 | 170 | prompt_image: "output/example/girl4.jpg" 171 | 172 | # Whether or not to sample preview during training (Requires more VRAM). 173 | sample_preview: True 174 | 175 | # The number of frames to sample during validation. 176 | num_frames: 14 177 | 178 | # Height and width of validation sample. 179 | width: 512 180 | height: 512 181 | 182 | # Number of inference steps when generating the video. 183 | num_inference_steps: 25 184 | 185 | # CFG scale 186 | guidance_scale: 9 187 | 188 | # fps 189 | fps: 7 190 | 191 | # The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. 192 | motion_bucket_id: 127 193 | 194 | # The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency between frames, 195 | # but also the higher the memory consumption. By default, the decoder will decode all frames at once for maximal quality. 196 | # Reduce `decode_chunk_size` to reduce memory usage. 197 | decode_chunk_size: 7 198 | 199 | # Learning rate for AdamW 200 | learning_rate: 5e-6 201 | 202 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 203 | adam_weight_decay: 0 204 | 205 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 206 | extra_unet_params: null 207 | #learning_rate: 1e-5 208 | #adam_weight_decay: 1e-4 209 | 210 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 211 | extra_text_encoder_params: null 212 | #learning_rate: 1e-4 213 | #adam_weight_decay: 0.2 214 | 215 | # How many batches to train. Not to be confused with video frames. 216 | train_batch_size: 1 217 | # Maximum number of train steps. Model is saved after training. 218 | max_train_steps: 10000 219 | 220 | # Saves a model every nth step. 221 | checkpointing_steps: 2500 222 | 223 | # How many steps to do for validation if sample_preview is enabled. 224 | validation_steps: 300 225 | 226 | # Which modules we want to unfreeze for the UNET. Advanced usage. 227 | trainable_modules: 228 | - "all" 229 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions" 230 | # This is for self attetion. Activates for spatial and temporal dimensions if n_sample_frames > 1 231 | - "attn1" 232 | #- ".attentions" 233 | 234 | # This is for cross attention (image & text data). 
Activates for spatial and temporal dimensions if n_sample_frames > 1 235 | - "attn2" 236 | - "conv_in" 237 | 238 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1 239 | - "temp_conv" 240 | - "motion" 241 | 242 | 243 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 244 | trainable_text_modules: null 245 | 246 | # Seed for validation. 247 | seed: 6 248 | 249 | # Whether or not we want to use mixed precision with accelerate 250 | mixed_precision: "fp16" 251 | 252 | # This seems to be incompatible at the moment. 253 | use_8bit_adam: False 254 | 255 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 256 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 257 | gradient_checkpointing: True 258 | text_encoder_gradient_checkpointing: False 259 | 260 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 261 | enable_xformers_memory_efficient_attention: False 262 | 263 | # Use scaled dot product attention (Only available with >= Torch 2.0) 264 | enable_torch_2_attn: True 265 | -------------------------------------------------------------------------------- /example/train_svd_mask.yaml: -------------------------------------------------------------------------------- 1 | # Pretrained diffusers model path. 2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main 3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid-mask" 4 | 5 | motion_mask: True 6 | motion_strength: False 7 | 8 | # The folder where your training outputs will be placed. 9 | output_dir: "./output/svd" 10 | 11 | # You can train multiple datasets at once. They will be joined together for training. 12 | # Simply remove the line you don't need, or keep them all for mixed training. 13 | 14 | # 'image': A folder of images and captions (.txt) 15 | # 'folder': A folder a videos and captions (.txt) 16 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 17 | # 'video_json': a video foler and a json caption file 18 | # 'single_video': A single video file.mp4 and text prompt 19 | dataset_types: 20 | #- 'single_video' 21 | #- 'folder' 22 | #- 'image' 23 | - 'video_blip' 24 | #- 'video_json' 25 | 26 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise 27 | # If this is enabled, rescale_schedule will be disabled. 28 | offset_noise_strength: 0.1 29 | use_offset_noise: False 30 | 31 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf 32 | # If this is enabled, offset noise will be disabled. 33 | rescale_schedule: False 34 | 35 | # When True, this extends all items in all enabled datasets to the highest length. 36 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 37 | extend_dataset: False 38 | 39 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 40 | # The latents will be saved under your training folder, and loaded automatically for training. 41 | # This both saves memory and speeds up training and takes very little disk space. 
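[Note] The `cache_latents` option described above amounts to encoding each clip once with the VAE and reusing the stored latents during training. A rough sketch of that idea, assuming `vae` is a diffusers `AutoencoderKL` (this is not the repo's exact caching code):

```python
# Sketch of latent caching: encode frames once, save to disk, reload for training.
import torch

@torch.no_grad()
def cache_clip_latents(vae, pixel_values: torch.Tensor, out_path: str):
    # pixel_values: (frames, channels, height, width), scaled to [-1, 1]
    latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
    torch.save({"latents": latents.cpu()}, out_path)
    return latents
```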
42 | cache_latents: False 43 | 44 | # If you have cached latents set to `True` and have a directory of cached latents, 45 | # you can skip the caching process and load previously saved ones. 46 | cached_latent_dir: null #/path/to/cached_latents 47 | 48 | # Train the text encoder for the model. LoRA Training overrides this setting. 49 | train_text_encoder: False 50 | 51 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 52 | # This is the first, original implementation of LoRA by cloneofsimo. 53 | # Use this version if you want to maintain compatibility to the original version. 54 | 55 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 56 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 57 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 58 | # This means that there isn't an extra function when doing low ranking adaption. 59 | # It solely saves the weight differential between the initialized weights and updates. 60 | 61 | # "cloneofsimo" or "stable_lora" 62 | lora_version: "cloneofsimo" 63 | 64 | # Use LoRA for the UNET model. 65 | use_unet_lora: False 66 | 67 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 68 | use_text_lora: False 69 | 70 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 71 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 72 | lora_unet_dropout: 0.1 73 | 74 | lora_text_dropout: 0.1 75 | 76 | # https://github.com/kabachuha/sd-webui-text2video 77 | # This saves a LoRA that is compatible with the text2video webui extension. 78 | # It only works when the lora version is 'stable_lora'. 79 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 80 | save_lora_for_webui: True 81 | 82 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 83 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 84 | # when this version is set to False 85 | only_lora_for_webui: False 86 | 87 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 88 | # The only time you want this off is if you're doing full LoRA training. 89 | save_pretrained_model: True 90 | 91 | # The modules to use for LoRA. Different from 'trainable_modules'. 92 | unet_lora_modules: 93 | - "UNet3DConditionModel" 94 | #- "ResnetBlock2D" 95 | #- "TransformerTemporalModel" 96 | #- "Transformer2DModel" 97 | #- "CrossAttention" 98 | #- "Attention" 99 | #- "GEGLU" 100 | #- "TemporalConvLayer" 101 | 102 | # The modules to use for LoRA. Different from `trainable_text_modules`. 103 | text_encoder_lora_modules: 104 | - "CLIPEncoderLayer" 105 | #- "CLIPAttention" 106 | 107 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 108 | # VRAM increases with higher rank, lower when decreased. 109 | lora_rank: 16 110 | 111 | # Training data parameters 112 | train_data: 113 | 114 | # The width and height in which you want your training data to be resized to. 115 | width: 512 116 | height: 512 117 | 118 | # This will find the closest aspect ratio to your input width and height. 
119 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256 120 | use_bucketing: False 121 | return_mask: True 122 | return_motion: True 123 | 124 | # The start frame index where your videos should start (Leave this at one for json and folder based training). 125 | sample_start_idx: 1 126 | 127 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset. 128 | # high fps, lower frame step, move slowly 129 | fps: 7 130 | 131 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...). 132 | frame_step: 1 133 | 134 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size). 135 | n_sample_frames: 12 136 | 137 | # 'single_video' 138 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4" 139 | 140 | # The prompt when using a a single video file 141 | single_video_prompt: "a bird" 142 | 143 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'. 144 | fallback_prompt: '' 145 | 146 | # 'folder' 147 | #path: "/data2/webvid/data/videos/004151_004200" 148 | path: "/data/datasets/msvd/videos_mp4" 149 | 150 | # 'json' 151 | json_path: '/webvid/animation1.json' 152 | 153 | # 'image' 154 | image_dir: '/vlp/datasets/images/coco' 155 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json' 156 | 157 | video_dir: '/webvid/webvid/data/videos' 158 | video_json: '/webvid/webvid/data/40K.json' 159 | # The prompt for all image files. Leave blank to use caption files (.txt) 160 | single_img_prompt: "" 161 | 162 | 163 | # Validation data parameters. 164 | validation_data: 165 | 166 | # A custom prompt that is different from your training dataset. 167 | prompt: "a girl moves body" 168 | 169 | prompt_image: "output/example/fish1.jpg" 170 | 171 | # Whether or not to sample preview during training (Requires more VRAM). 172 | sample_preview: True 173 | 174 | # The number of frames to sample during validation. 175 | num_frames: 14 176 | 177 | # Height and width of validation sample. 178 | width: 512 179 | height: 512 180 | 181 | # Number of inference steps when generating the video. 182 | num_inference_steps: 25 183 | 184 | # CFG scale 185 | guidance_scale: 9 186 | 187 | # fps 188 | fps: 7 189 | 190 | # The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. 191 | motion_bucket_id: 127 192 | 193 | # The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency between frames, 194 | # but also the higher the memory consumption. By default, the decoder will decode all frames at once for maximal quality. 195 | # Reduce `decode_chunk_size` to reduce memory usage. 196 | decode_chunk_size: 7 197 | 198 | # Learning rate for AdamW 199 | learning_rate: 2e-5 200 | 201 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 202 | adam_weight_decay: 0 203 | 204 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 205 | extra_unet_params: null 206 | #learning_rate: 1e-5 207 | #adam_weight_decay: 1e-4 208 | 209 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 210 | extra_text_encoder_params: null 211 | #learning_rate: 1e-4 212 | #adam_weight_decay: 0.2 213 | 214 | # How many batches to train. Not to be confused with video frames. 
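[Note] The validation keys above (`fps`, `motion_bucket_id`, `decode_chunk_size`, `num_frames`) mirror the conditioning inputs of Stable Video Diffusion. Purely to show what those parameters mean, here is a hedged sketch using the stock diffusers `StableVideoDiffusionPipeline`; this repo trains and samples with its own masked SVD pipeline, and the model path below is a placeholder.

```python
# Sketch of how these validation parameters map onto a stock SVD call in diffusers.
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16
).to("cuda")

image = load_image("example/fish1.jpg").resize((512, 512))
frames = pipe(
    image,
    height=512, width=512,
    num_frames=14,
    num_inference_steps=25,
    fps=7,                    # frame-rate conditioning
    motion_bucket_id=127,     # higher -> more motion
    decode_chunk_size=7,      # decode fewer frames at once to save VRAM
).frames[0]
```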
215 | train_batch_size: 3 216 | # Maximum number of train steps. Model is saved after training. 217 | max_train_steps: 20000 218 | 219 | # Saves a model every nth step. 220 | checkpointing_steps: 2500 221 | 222 | # How many steps to do for validation if sample_preview is enabled. 223 | validation_steps: 100 224 | 225 | # Which modules we want to unfreeze for the UNET. Advanced usage. 226 | trainable_modules: 227 | - "all" 228 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions" 229 | # This is for self attetion. Activates for spatial and temporal dimensions if n_sample_frames > 1 230 | - "attn1" 231 | - ".attentions" 232 | 233 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1 234 | - "attn2" 235 | - "conv_in" 236 | 237 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1 238 | - "temp_conv" 239 | - "motion" 240 | 241 | 242 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 243 | trainable_text_modules: null 244 | 245 | # Seed for validation. 246 | seed: 6 247 | 248 | # Whether or not we want to use mixed precision with accelerate 249 | mixed_precision: "fp16" 250 | 251 | # This seems to be incompatible at the moment. 252 | use_8bit_adam: False 253 | 254 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 255 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 256 | gradient_checkpointing: True 257 | text_encoder_gradient_checkpointing: False 258 | 259 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 260 | enable_xformers_memory_efficient_attention: False 261 | 262 | # Use scaled dot product attention (Only available with >= Torch 2.0) 263 | enable_torch_2_attn: True 264 | -------------------------------------------------------------------------------- /example/train_svd_v2v.yaml: -------------------------------------------------------------------------------- 1 | # Pretrained diffusers model path. 2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main 3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid" 4 | 5 | 6 | motion_mask: True 7 | motion_strength: False 8 | 9 | # The folder where your training outputs will be placed. 10 | output_dir: "./output/svd" 11 | 12 | # You can train multiple datasets at once. They will be joined together for training. 13 | # Simply remove the line you don't need, or keep them all for mixed training. 14 | 15 | # 'image': A folder of images and captions (.txt) 16 | # 'folder': A folder a videos and captions (.txt) 17 | # 'json': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor 18 | # 'video_json': a video foler and a json caption file 19 | # 'single_video': A single video file.mp4 and text prompt 20 | dataset_types: 21 | #- 'single_video' 22 | #- 'folder' 23 | #- 'image' 24 | - 'video_blip' 25 | #- 'video_json' 26 | 27 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise 28 | # If this is enabled, rescale_schedule will be disabled. 29 | offset_noise_strength: 0.1 30 | use_offset_noise: False 31 | 32 | # Uses schedule rescale, also known as the "better" offset noise. 
See https://arxiv.org/pdf/2305.08891.pdf 33 | # If this is enabled, offset noise will be disabled. 34 | rescale_schedule: False 35 | 36 | # When True, this extends all items in all enabled datasets to the highest length. 37 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200. 38 | extend_dataset: False 39 | 40 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD. 41 | # The latents will be saved under your training folder, and loaded automatically for training. 42 | # This both saves memory and speeds up training and takes very little disk space. 43 | cache_latents: False 44 | 45 | # If you have cached latents set to `True` and have a directory of cached latents, 46 | # you can skip the caching process and load previously saved ones. 47 | cached_latent_dir: null #/path/to/cached_latents 48 | 49 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension) 50 | # This is the first, original implementation of LoRA by cloneofsimo. 51 | # Use this version if you want to maintain compatibility to the original version. 52 | 53 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension) 54 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here. 55 | # It works a different by using embeddings instead of the intermediate activations (Linear || Conv). 56 | # This means that there isn't an extra function when doing low ranking adaption. 57 | # It solely saves the weight differential between the initialized weights and updates. 58 | 59 | # "cloneofsimo" or "stable_lora" 60 | lora_version: "cloneofsimo" 61 | 62 | # Use LoRA for the UNET model. 63 | use_unet_lora: False 64 | 65 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained. 66 | use_text_lora: False 67 | 68 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting. 69 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html 70 | lora_unet_dropout: 0.1 71 | 72 | lora_text_dropout: 0.1 73 | 74 | # https://github.com/kabachuha/sd-webui-text2video 75 | # This saves a LoRA that is compatible with the text2video webui extension. 76 | # It only works when the lora version is 'stable_lora'. 77 | # This is also a DIFFERENT implementation than Kohya's, so it will NOT work the same implementation. 78 | save_lora_for_webui: True 79 | 80 | # The LoRA file will be converted to a different format to be compatible with the webui extension. 81 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model 82 | # when this version is set to False 83 | only_lora_for_webui: False 84 | 85 | # Choose whether or not ito save the full pretrained model weights for both checkpoints and after training. 86 | # The only time you want this off is if you're doing full LoRA training. 87 | save_pretrained_model: True 88 | 89 | # The modules to use for LoRA. Different from 'trainable_modules'. 90 | unet_lora_modules: 91 | - "UNet3DConditionModel" 92 | #- "ResnetBlock2D" 93 | #- "TransformerTemporalModel" 94 | #- "Transformer2DModel" 95 | #- "CrossAttention" 96 | #- "Attention" 97 | #- "GEGLU" 98 | #- "TemporalConvLayer" 99 | 100 | # The modules to use for LoRA. Different from `trainable_text_modules`. 
101 | text_encoder_lora_modules: 102 | - "CLIPEncoderLayer" 103 | #- "CLIPAttention" 104 | 105 | # The rank for LoRA training. With ModelScope, the maximum should be 1024. 106 | # VRAM increases with higher rank, lower when decreased. 107 | lora_rank: 16 108 | 109 | # Training data parameters 110 | train_data: 111 | 112 | # The width and height in which you want your training data to be resized to. 113 | width: 576 114 | height: 1024 115 | 116 | # This will find the closest aspect ratio to your input width and height. 117 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256 118 | use_bucketing: False 119 | return_mask: True 120 | return_motion: True 121 | 122 | # The start frame index where your videos should start (Leave this at one for json and folder based training). 123 | sample_start_idx: 1 124 | 125 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset. 126 | # high fps, lower frame step, move slowly 127 | fps: 7 128 | 129 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...). 130 | frame_step: 1 131 | 132 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size). 133 | n_sample_frames: 14 134 | 135 | # 'single_video' 136 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4" 137 | 138 | # The prompt when using a a single video file 139 | single_video_prompt: "a bird" 140 | 141 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'. 142 | fallback_prompt: '' 143 | 144 | # 'folder' 145 | #path: "/data2/webvid/data/videos/004151_004200" 146 | path: "/data/datasets/msvd/videos_mp4" 147 | 148 | # 'json' 149 | json_path: '/webvid/animation2.json' 150 | 151 | # 'image' 152 | image_dir: '/vlp/datasets/images/coco' 153 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json' 154 | 155 | #video_dir: '/mnt/cap/zuozhuo/webvid/webvid/data/videos' 156 | #video_json: '/mnt/cap/zuozhuo/webvid/webvid/data/1M.json' 157 | # The prompt for all image files. Leave blank to use caption files (.txt) 158 | single_img_prompt: "" 159 | 160 | video_dir: '/webvid/webvid/data/videos' 161 | video_json: '/webvid/webvid/data/40K.json' 162 | 163 | extra_train_data: 164 | - dataset_types: 165 | - video_blip 166 | train_data: 167 | json_path: '/webvid/animation_dataset_clips_part_0.json' 168 | - dataset_types: 169 | - video_blip 170 | train_data: 171 | json_path: '/webvid/animation_dataset_clips_part_1.json' 172 | - dataset_types: 173 | - video_blip 174 | train_data: 175 | json_path: '/webvid/animation_dataset_clips_part_2.json' 176 | - dataset_types: 177 | - video_blip 178 | train_data: 179 | json_path: '/webvid/animation0.json' 180 | - dataset_types: 181 | - video_blip 182 | train_data: 183 | json_path: '/webvid/animation1.json' 184 | 185 | 186 | # Validation data parameters. 187 | validation_data: 188 | 189 | # A custom prompt that is different from your training dataset. 190 | prompt: "The fish is swimming." 191 | 192 | prompt_image: "output/example/fish_512.mp4" 193 | 194 | # Whether or not to sample preview during training (Requires more VRAM). 195 | sample_preview: True 196 | 197 | # The number of frames to sample during validation. 198 | num_frames: 14 199 | 200 | # Height and width of validation sample. 201 | width: 512 202 | height: 512 203 | # Number of inference steps when generating the video. 
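[Note] The `extra_train_data` block above lists additional JSON-captioned sets that are joined with the main dataset for training. Conceptually this is a dataset concatenation, sketched below with torch's `ConcatDataset` and a hypothetical stand-in dataset class; the repo's real dataset classes live in `utils/dataset.py`.

```python
# Conceptual sketch of joining the main dataset with the extra_train_data entries.
from torch.utils.data import ConcatDataset, Dataset

class DummyClips(Dataset):
    """Hypothetical stand-in for a video_blip-style dataset built from one JSON file."""
    def __init__(self, json_path: str, length: int = 10):
        self.json_path, self.length = json_path, length
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        return {"json_path": self.json_path, "index": idx}

parts = ["/webvid/animation2.json", "/webvid/animation_dataset_clips_part_0.json"]
train_dataset = ConcatDataset([DummyClips(p) for p in parts])  # samples drawn across all parts
print(len(train_dataset))
```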
204 | num_inference_steps: 25 205 | fps: 7 206 | # CFG scale 207 | guidance_scale: 3 208 | 209 | # Learning rate for AdamW 210 | learning_rate: 5e-6 211 | 212 | # Weight decay. Higher = more regularization. Lower = closer to dataset. 213 | adam_weight_decay: 0 214 | 215 | # Optimizer parameters for the UNET. Overrides base learning rate parameters. 216 | extra_unet_params: null 217 | #learning_rate: 1e-5 218 | #adam_weight_decay: 1e-4 219 | 220 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters. 221 | extra_text_encoder_params: null 222 | #learning_rate: 1e-4 223 | #adam_weight_decay: 0.2 224 | 225 | # How many batches to train. Not to be confused with video frames. 226 | train_batch_size: 1 227 | gradient_accumulation_steps: 2 228 | # Maximum number of train steps. Model is saved after training. 229 | max_train_steps: 10000 230 | 231 | # Saves a model every nth step. 232 | checkpointing_steps: 1000 233 | 234 | # How many steps to do for validation if sample_preview is enabled. 235 | validation_steps: 100 236 | 237 | # Which modules we want to unfreeze for the UNET. Advanced usage. 238 | trainable_modules: 239 | - "all" 240 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions" 241 | # This is for self attetion. Activates for spatial and temporal dimensions if n_sample_frames > 1 242 | #- "attn1" 243 | - ".attentions" 244 | 245 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1 246 | #- "attn2" 247 | 248 | # sample input and output 249 | - "conv_in" 250 | - "conv_out" 251 | 252 | # Time condition 253 | - '_proj' 254 | - '_embedding' 255 | 256 | 257 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage. 258 | trainable_text_modules: 259 | null 260 | #- "embedding" 261 | 262 | # Seed for validation. 263 | seed: 6 264 | 265 | # Whether or not we want to use mixed precision with accelerate 266 | mixed_precision: "fp16" 267 | 268 | # This seems to be incompatible at the moment. 269 | use_8bit_adam: False 270 | 271 | # Trades VRAM usage for speed. You lose roughly 20% of training speed, but save a lot of VRAM. 272 | # If you need to save more VRAM, it can also be enabled for the text encoder, but reduces speed x2. 
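[Note] The `trainable_modules` list above selects which UNET parameters stay unfrozen by matching the listed (sub)strings against module/parameter names. A rough sketch of that selection logic follows; the training scripts' exact matching rules may differ.

```python
# Rough sketch: unfreeze only parameters whose names match the trainable_modules list.
import torch.nn as nn

def set_trainable(model: nn.Module, trainable_modules):
    for name, param in model.named_parameters():
        if "all" in trainable_modules:
            param.requires_grad_(True)
        else:
            param.requires_grad_(any(key in name for key in trainable_modules))

# e.g. set_trainable(unet, ["conv_in", "conv_out", "_proj", "_embedding", ".attentions"])
```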
273 | gradient_checkpointing: True 274 | text_encoder_gradient_checkpointing: False 275 | 276 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0) 277 | enable_xformers_memory_efficient_attention: False 278 | 279 | # Use scaled dot product attention (Only available with >= Torch 2.0) 280 | enable_torch_2_attn: True 281 | -------------------------------------------------------------------------------- /example/validation_file.json: -------------------------------------------------------------------------------- 1 | [ 2 | ["example/fish1.jpg", "The red fish is swimming"], 3 | ["example/barbie2.jpg", "a girl is talking, move head"], 4 | ["example/hulu3.jpg", "The man is talking, move hands."], 5 | ["example/pig0.jpg", "Three cartoon pigs are talking."], 6 | ["example/qingming2.jpg", "ships are sailing on the river."] 7 | ] 8 | 9 | -------------------------------------------------------------------------------- /models/layerdiffuse_VAE.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from typing import Optional, Tuple 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models.modeling_utils import ModelMixin 7 | from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block 8 | 9 | # referenced from https://github.com/layerdiffusion/sd-forge-layerdiffuse/blob/main/lib_layerdiffusion/models.py 10 | 11 | def zero_module(module): 12 | for p in module.parameters(): 13 | p.detach().zero_() 14 | return module 15 | 16 | 17 | class LatentTransparencyOffsetEncoder(torch.nn.Module): 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.blocks = torch.nn.Sequential( 21 | torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1), 22 | nn.SiLU(), 23 | torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1), 24 | nn.SiLU(), 25 | torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2), 26 | nn.SiLU(), 27 | torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1), 28 | nn.SiLU(), 29 | torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), 30 | nn.SiLU(), 31 | torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1), 32 | nn.SiLU(), 33 | torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), 34 | nn.SiLU(), 35 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), 36 | nn.SiLU(), 37 | zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)), 38 | ) 39 | 40 | def __call__(self, x): 41 | return self.blocks(x) 42 | 43 | 44 | class UNet384(ModelMixin, ConfigMixin): 45 | @register_to_config 46 | def __init__( 47 | self, 48 | in_channels: int = 3, 49 | out_channels: int = 4, 50 | down_block_types: Tuple[str] = ("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"), 51 | up_block_types: Tuple[str] = ("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"), 52 | block_out_channels: Tuple[int] = (32, 64, 128, 256), 53 | layers_per_block: int = 2, 54 | mid_block_scale_factor: float = 1, 55 | downsample_padding: int = 1, 56 | downsample_type: str = "conv", 57 | upsample_type: str = "conv", 58 | dropout: float = 0.0, 59 | act_fn: str = "silu", 60 | attention_head_dim: Optional[int] = 8, 61 | norm_num_groups: int = 4, 62 | norm_eps: float = 1e-5, 63 | ): 64 | super().__init__() 65 | 66 | # input 67 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 68 | self.latent_conv_in = 
zero_module(nn.Conv2d(4, block_out_channels[2], kernel_size=1)) 69 | 70 | self.down_blocks = nn.ModuleList([]) 71 | self.mid_block = None 72 | self.up_blocks = nn.ModuleList([]) 73 | 74 | # down 75 | output_channel = block_out_channels[0] 76 | for i, down_block_type in enumerate(down_block_types): 77 | input_channel = output_channel 78 | output_channel = block_out_channels[i] 79 | is_final_block = i == len(block_out_channels) - 1 80 | 81 | down_block = get_down_block( 82 | down_block_type, 83 | num_layers=layers_per_block, 84 | in_channels=input_channel, 85 | out_channels=output_channel, 86 | temb_channels=None, 87 | add_downsample=not is_final_block, 88 | resnet_eps=norm_eps, 89 | resnet_act_fn=act_fn, 90 | resnet_groups=norm_num_groups, 91 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 92 | downsample_padding=downsample_padding, 93 | resnet_time_scale_shift="default", 94 | downsample_type=downsample_type, 95 | dropout=dropout, 96 | ) 97 | self.down_blocks.append(down_block) 98 | 99 | # mid 100 | self.mid_block = UNetMidBlock2D( 101 | in_channels=block_out_channels[-1], 102 | temb_channels=None, 103 | dropout=dropout, 104 | resnet_eps=norm_eps, 105 | resnet_act_fn=act_fn, 106 | output_scale_factor=mid_block_scale_factor, 107 | resnet_time_scale_shift="default", 108 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], 109 | resnet_groups=norm_num_groups, 110 | attn_groups=None, 111 | add_attention=True, 112 | ) 113 | 114 | # up 115 | reversed_block_out_channels = list(reversed(block_out_channels)) 116 | output_channel = reversed_block_out_channels[0] 117 | for i, up_block_type in enumerate(up_block_types): 118 | prev_output_channel = output_channel 119 | output_channel = reversed_block_out_channels[i] 120 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 121 | 122 | is_final_block = i == len(block_out_channels) - 1 123 | 124 | up_block = get_up_block( 125 | up_block_type, 126 | num_layers=layers_per_block + 1, 127 | in_channels=input_channel, 128 | out_channels=output_channel, 129 | prev_output_channel=prev_output_channel, 130 | temb_channels=None, 131 | add_upsample=not is_final_block, 132 | resnet_eps=norm_eps, 133 | resnet_act_fn=act_fn, 134 | resnet_groups=norm_num_groups, 135 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 136 | resnet_time_scale_shift="default", 137 | upsample_type=upsample_type, 138 | dropout=dropout, 139 | ) 140 | self.up_blocks.append(up_block) 141 | prev_output_channel = output_channel 142 | 143 | # out 144 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 145 | self.conv_act = nn.SiLU() 146 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 147 | 148 | def forward(self, x, latent): 149 | sample_latent = self.latent_conv_in(latent) 150 | sample = self.conv_in(x) 151 | emb = None 152 | 153 | down_block_res_samples = (sample,) 154 | for i, downsample_block in enumerate(self.down_blocks): 155 | # 8X downsample 156 | if i == 3: 157 | sample = sample + sample_latent 158 | 159 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 160 | down_block_res_samples += res_samples 161 | 162 | assert len(self.down_blocks) == 4 163 | 164 | sample = self.mid_block(sample, emb) 165 | 166 | for upsample_block in self.up_blocks: 167 | res_samples = 
down_block_res_samples[-len(upsample_block.resnets) :] 168 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 169 | sample = upsample_block(sample, res_samples, emb) 170 | 171 | sample = self.conv_norm_out(sample) 172 | sample = self.conv_act(sample) 173 | sample = self.conv_out(sample) 174 | return sample 175 | 176 | def __call__(self, x, latent): 177 | return self.forward(x, latent) 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | torch==2.0.0 3 | torchvision 4 | diffusers==0.24.0 5 | transformers==4.36.2 6 | einops 7 | decord 8 | tqdm 9 | safetensors 10 | omegaconf 11 | opencv-python 12 | pydantic 13 | compel 14 | easydict 15 | rotary_embedding_torch 16 | imageio[ffmpeg] 17 | gradio 18 | httpx[socks] 19 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python train.py --config ./configs/v2/infer_config_latent.yaml --eval 2 | accelerate launch train.py --config configs/v2/train_config_latent.yaml 3 | 4 | -------------------------------------------------------------------------------- /stable_lora/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import loralib as loralb 6 | from loralib import LoRALayer 7 | import math 8 | import json 9 | 10 | from torch.utils.data import ConcatDataset 11 | from transformers import CLIPTokenizer 12 | 13 | try: 14 | from safetensors.torch import save_file, load_file 15 | except: 16 | print("Safetensors is not installed. Saving while using use_safetensors will fail.") 17 | 18 | UNET_REPLACE = ["Transformer2DModel", "ResnetBlock2D"] 19 | TEXT_ENCODER_REPLACE = ["CLIPAttention", "CLIPTextEmbeddings"] 20 | 21 | UNET_ATTENTION_REPLACE = ["CrossAttention"] 22 | TEXT_ENCODER_ATTENTION_REPLACE = ["CLIPAttention", "CLIPTextEmbeddings"] 23 | 24 | """ 25 | Copied from: https://github.com/cloneofsimo/lora/blob/bdd51b04c49fa90a88919a19850ec3b4cf3c5ecd/lora_diffusion/lora.py#L189 26 | """ 27 | def find_modules( 28 | model, 29 | ancestor_class= None, 30 | search_class = [torch.nn.Linear], 31 | exclude_children_of = [loralb.Linear, loralb.Conv2d, loralb.Embedding], 32 | ): 33 | """ 34 | Find all modules of a certain class (or union of classes) that are direct or 35 | indirect descendants of other modules of a certain class (or union of classes). 36 | 37 | Returns all matching modules, along with the parent of those moduless and the 38 | names they are referenced by. 39 | """ 40 | 41 | # Get the targets we should replace all linears under 42 | if ancestor_class is not None: 43 | ancestors = ( 44 | module 45 | for module in model.modules() 46 | if module.__class__.__name__ in ancestor_class 47 | ) 48 | else: 49 | # this, incase you want to naively iterate over all modules. 
50 | ancestors = [module for module in model.modules()] 51 | 52 | # For each target find every linear_class module that isn't a child of a LoraInjectedLinear 53 | for ancestor in ancestors: 54 | for fullname, module in ancestor.named_modules(): 55 | if any([isinstance(module, _class) for _class in search_class]): 56 | # Find the direct parent if this is a descendant, not a child, of target 57 | *path, name = fullname.split(".") 58 | parent = ancestor 59 | while path: 60 | parent = parent.get_submodule(path.pop(0)) 61 | # Skip this linear if it's a child of a LoraInjectedLinear 62 | if exclude_children_of and any( 63 | [isinstance(parent, _class) for _class in exclude_children_of] 64 | ): 65 | continue 66 | # Otherwise, yield it 67 | yield parent, name, module 68 | 69 | class Conv2d(nn.Conv2d, LoRALayer): 70 | # LoRA implemented in a dense layer 71 | def __init__( 72 | self, 73 | in_channels: int, 74 | out_channels: int, 75 | kernel_size: int, 76 | r: int = 0, 77 | lora_alpha: int = 1, 78 | lora_dropout: float = 0., 79 | merge_weights: bool = True, 80 | **kwargs 81 | ): 82 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) 83 | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 84 | merge_weights=merge_weights) 85 | assert type(kernel_size) is int 86 | # Actual trainable parameters 87 | if r > 0: 88 | self.lora_A = nn.Parameter( 89 | self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) 90 | ) 91 | self.lora_B = nn.Parameter( 92 | self.weight.new_zeros((out_channels*kernel_size, r*kernel_size)) 93 | ) 94 | self.scaling = self.lora_alpha / self.r 95 | # Freezing the pre-trained weight matrix 96 | self.weight.requires_grad = False 97 | self.reset_parameters() 98 | 99 | def reset_parameters(self): 100 | nn.Conv2d.reset_parameters(self) 101 | if hasattr(self, 'lora_A'): 102 | # initialize A the same way as the default for nn.Linear and B to zero 103 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 104 | nn.init.zeros_(self.lora_B) 105 | 106 | def train(self, mode: bool = True): 107 | nn.Conv2d.train(self, mode) 108 | if mode: 109 | if self.merge_weights and self.merged: 110 | # Make sure that the weights are not merged 111 | self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling 112 | self.merged = False 113 | else: 114 | if self.merge_weights and not self.merged: 115 | # Merge the weights and mark it 116 | self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling 117 | self.merged = True 118 | 119 | def forward(self, x: torch.Tensor): 120 | if self.r > 0 and not self.merged: 121 | return F.conv2d( 122 | x, 123 | self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, 124 | self.bias, self.stride, self.padding, self.dilation, self.groups 125 | ) 126 | return nn.Conv2d.forward(self, x) 127 | 128 | class Conv3d(nn.Conv3d, LoRALayer): 129 | # LoRA implemented in a dense layer 130 | def __init__( 131 | self, 132 | in_channels: int, 133 | out_channels: int, 134 | kernel_size: int, 135 | r: int = 0, 136 | lora_alpha: int = 1, 137 | lora_dropout: float = 0., 138 | merge_weights: bool = True, 139 | **kwargs 140 | ): 141 | nn.Conv3d.__init__(self, in_channels, out_channels, (kernel_size, 1, 1), **kwargs) 142 | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 143 | merge_weights=merge_weights) 144 | assert type(kernel_size) is int 145 | # Actual trainable parameters 146 | 147 | # Get view transform shape 148 | i, o, k 
= self.weight.shape[:3] 149 | self.view_shape = (i, o, k, kernel_size, 1) 150 | self.force_disable_merge = True 151 | 152 | if r > 0: 153 | self.lora_A = nn.Parameter( 154 | self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) 155 | ) 156 | self.lora_B = nn.Parameter( 157 | self.weight.new_zeros((out_channels*kernel_size, r*kernel_size)) 158 | ) 159 | self.scaling = self.lora_alpha / self.r 160 | # Freezing the pre-trained weight matrix 161 | self.weight.requires_grad = False 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self): 165 | nn.Conv3d.reset_parameters(self) 166 | if hasattr(self, 'lora_A'): 167 | # initialize A the same way as the default for nn.Linear and B to zero 168 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 169 | nn.init.zeros_(self.lora_B) 170 | 171 | def train(self, mode: bool = True): 172 | nn.Conv3d.train(self, mode) 173 | 174 | # HACK Merging the weights this way could potentially cause vanishing gradients if validation is enabled. 175 | # If you are to save this as a pretrained model, you will have to merge these weights afterwards, then save. 176 | if self.force_disable_merge: 177 | return 178 | 179 | if mode: 180 | if self.merge_weights and self.merged: 181 | # Make sure that the weights are not merged 182 | self.weight.data -= torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * self.scaling 183 | self.merged = False 184 | else: 185 | if self.merge_weights and not self.merged: 186 | # Merge the weights and mark it 187 | self.weight.data += torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * self.scaling 188 | self.merged = True 189 | 190 | def forward(self, x: torch.Tensor): 191 | if self.r > 0 and not self.merged: 192 | return F.conv3d( 193 | x, 194 | self.weight + torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * \ 195 | self.scaling, self.bias, self.stride, self.padding, self.dilation, self.groups 196 | ) 197 | return nn.Conv3d.forward(self, x) 198 | 199 | def create_lora_linear(child_module, r, dropout=0, bias=False, scale=0): 200 | return loralb.Linear( 201 | child_module.in_features, 202 | child_module.out_features, 203 | merge_weights=False, 204 | bias=bias, 205 | lora_dropout=dropout, 206 | lora_alpha=r, 207 | r=r 208 | ) 209 | return lora_linear 210 | 211 | def create_lora_conv(child_module, r, dropout=0, bias=False, rescale=False, scale=0): 212 | return Conv2d( 213 | child_module.in_channels, 214 | child_module.out_channels, 215 | kernel_size=child_module.kernel_size[0], 216 | padding=child_module.padding, 217 | stride=child_module.stride, 218 | merge_weights=False, 219 | bias=bias, 220 | lora_dropout=dropout, 221 | lora_alpha=r, 222 | r=r, 223 | ) 224 | return lora_conv 225 | 226 | def create_lora_conv3d(child_module, r, dropout=0, bias=False, rescale=False, scale=0): 227 | return Conv3d( 228 | child_module.in_channels, 229 | child_module.out_channels, 230 | kernel_size=child_module.kernel_size[0], 231 | padding=child_module.padding, 232 | stride=child_module.stride, 233 | merge_weights=False, 234 | bias=bias, 235 | lora_dropout=dropout, 236 | lora_alpha=r, 237 | r=r, 238 | ) 239 | return lora_conv 240 | 241 | def create_lora_emb(child_module, r): 242 | return loralb.Embedding( 243 | child_module.num_embeddings, 244 | child_module.embedding_dim, 245 | merge_weights=False, 246 | lora_alpha=r, 247 | r=r 248 | ) 249 | 250 | def activate_lora_train(model, bias): 251 | def unfreeze(): 252 | 
print(model.__class__.__name__ + " LoRA set for training.") 253 | return loralb.mark_only_lora_as_trainable(model, bias=bias) 254 | 255 | return unfreeze 256 | 257 | def add_lora_to( 258 | model, 259 | target_module=UNET_REPLACE, 260 | search_class=[torch.nn.Linear], 261 | r=32, 262 | dropout=0, 263 | lora_bias='none' 264 | ): 265 | for module, name, child_module in find_modules( 266 | model, 267 | ancestor_class=target_module, 268 | search_class=search_class 269 | ): 270 | bias = hasattr(child_module, "bias") 271 | 272 | # Check if child module of the model has bias. 273 | if bias: 274 | if child_module.bias is None: 275 | bias = False 276 | 277 | # Check if the child module of the model is type Linear or Conv2d. 278 | if isinstance(child_module, torch.nn.Linear): 279 | l = create_lora_linear(child_module, r, dropout, bias=bias) 280 | 281 | if isinstance(child_module, torch.nn.Conv2d): 282 | l = create_lora_conv(child_module, r, dropout, bias=bias) 283 | 284 | if isinstance(child_module, torch.nn.Conv3d): 285 | l = create_lora_conv3d(child_module, r, dropout, bias=bias) 286 | 287 | if isinstance(child_module, torch.nn.Embedding): 288 | l = create_lora_emb(child_module, r) 289 | 290 | # If the model has bias and we wish to add it, use the child_modules in place 291 | if bias: 292 | l.bias = child_module.bias 293 | 294 | # Assign the frozen weight of model's Linear or Conv2d to the LoRA model. 295 | l.weight = child_module.weight 296 | 297 | # Replace the new LoRA model with the model's Linear or Conv2d module. 298 | module._modules[name] = l 299 | 300 | 301 | # Unfreeze only the newly added LoRA weights, but keep the model frozen. 302 | return activate_lora_train(model, lora_bias) 303 | 304 | def save_lora( 305 | unet=None, 306 | text_encoder=None, 307 | save_text_weights=False, 308 | output_dir="output", 309 | lora_filename="lora.safetensors", 310 | lora_bias='none', 311 | save_for_webui=True, 312 | only_webui=False, 313 | metadata=None, 314 | unet_dict_converter=None, 315 | text_dict_converter=None 316 | ): 317 | 318 | if not only_webui: 319 | # Create directory for the full LoRA weights. 320 | trainable_weights_dir = f"{output_dir}/full_weights" 321 | lora_out_file_full_weight = f"{trainable_weights_dir}/{lora_filename}" 322 | os.makedirs(trainable_weights_dir, exist_ok=True) 323 | 324 | ext = '.safetensors' 325 | # Create LoRA out filename. 326 | lora_out_file = f"{output_dir}/webui_{lora_filename}{ext}" 327 | 328 | if not only_webui: 329 | save_path_full_weights = lora_out_file_full_weight + ext 330 | 331 | save_path = lora_out_file 332 | 333 | if not only_webui: 334 | for i, model in enumerate([unet, text_encoder]): 335 | if save_text_weights and i == 1: 336 | non_webui_weights = save_path_full_weights.replace(ext, f"_text_encoder{ext}") 337 | 338 | else: 339 | non_webui_weights = save_path_full_weights.replace(ext, f"_unet{ext}") 340 | 341 | # Load only the LoRAs from the state dict. 342 | lora_dict = loralb.lora_state_dict(model, bias=lora_bias) 343 | 344 | # Save the models as fp32. This ensures we can finetune again without having to upcast. 
345 | save_file(lora_dict, non_webui_weights) 346 | 347 | if save_for_webui: 348 | # Convert the keys to compvis model and webui 349 | unet_lora_dict = loralb.lora_state_dict(unet, bias=lora_bias) 350 | lora_dict_fp16 = unet_dict_converter(unet_lora_dict, strict_mapping=True) 351 | 352 | if save_text_weights: 353 | text_encoder_dict = loralb.lora_state_dict(text_encoder, bias=lora_bias) 354 | lora_dict_text_fp16 = text_dict_converter(text_encoder_dict) 355 | 356 | # Update the Unet dict to include text keys. 357 | lora_dict_fp16.update(lora_dict_text_fp16) 358 | 359 | # Cast tensors to fp16. It's assumed we won't be finetuning these. 360 | for k, v in lora_dict_fp16.items(): 361 | lora_dict_fp16[k] = v.to(dtype=torch.float16) 362 | 363 | save_file( 364 | lora_dict_fp16, 365 | save_path, 366 | metadata=metadata 367 | ) 368 | 369 | def load_lora(model, lora_path: str): 370 | try: 371 | if os.path.exists(lora_path): 372 | lora_dict = load_file(lora_path) 373 | model.load_state_dict(lora_dict, strict=False) 374 | 375 | except Exception as e: 376 | print(f"Could not load your lora file: {e}") 377 | 378 | def set_mode(model, train=False): 379 | for n, m in model.named_modules(): 380 | is_lora = hasattr(m, 'merged') 381 | if is_lora: 382 | m.train(train) 383 | 384 | def set_mode_group(models, train): 385 | for model in models: 386 | set_mode(model, train) 387 | model.train(train) 388 | -------------------------------------------------------------------------------- /svd_video2video_examples/barbie_input.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_input.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/barbie_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_mask.png -------------------------------------------------------------------------------- /svd_video2video_examples/barbie_output.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_output.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/car_input.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_input.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/car_mask_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_mask_1.png -------------------------------------------------------------------------------- /svd_video2video_examples/car_mask_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_mask_2.png 
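[Editor's note] A minimal usage sketch for stable_lora/lora.py above, assuming a diffusers-style unet and text_encoder are already loaded and that the loralib package is available; the variable names, rank, and file names here are illustrative and not taken from the repository:

import torch
from stable_lora.lora import add_lora_to, save_lora, load_lora, UNET_REPLACE

# Inject LoRA layers into Linear/Conv2d modules found under the UNet's Transformer2DModel /
# ResnetBlock2D blocks, then mark only the injected lora_A / lora_B parameters as trainable.
unfreeze = add_lora_to(unet, target_module=UNET_REPLACE,
                       search_class=[torch.nn.Linear, torch.nn.Conv2d], r=16)
unfreeze()

# ... training loop: only the LoRA parameters receive gradient updates ...

# Save full-precision LoRA weights to output/full_weights/ and reload them later
# (load_lora calls load_state_dict with strict=False, so only matching LoRA entries are filled).
save_lora(unet=unet, text_encoder=text_encoder, save_text_weights=True,
          output_dir="output", lora_filename="my_lora", save_for_webui=False)
load_lora(unet, "output/full_weights/my_lora_unet.safetensors")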
-------------------------------------------------------------------------------- /svd_video2video_examples/car_output_1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_output_1.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/car_output_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_output_2.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/windmill_input.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_input.mp4 -------------------------------------------------------------------------------- /svd_video2video_examples/windmill_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_mask.png -------------------------------------------------------------------------------- /svd_video2video_examples/windmill_output.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_output.mp4 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/utils/__init__.py -------------------------------------------------------------------------------- /utils/bucketing.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def min_res(size, min_size): return 192 if size < 192 else size 4 | 5 | def up_down_bucket(m_size, in_size, direction): 6 | if direction == 'down': return abs(int(m_size - in_size)) 7 | if direction == 'up': return abs(int(m_size + in_size)) 8 | 9 | def get_bucket_sizes(size, direction: 'down', min_size): 10 | multipliers = [64, 128] 11 | for i, m in enumerate(multipliers): 12 | res = up_down_bucket(m, size, direction) 13 | multipliers[i] = min_res(res, min_size=min_size) 14 | return multipliers 15 | 16 | def closest_bucket(m_size, size, direction, min_size): 17 | lst = get_bucket_sizes(m_size, direction, min_size) 18 | return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))] 19 | 20 | def resolve_bucket(i,h,w): return (i / (h / w)) 21 | 22 | def sensible_buckets(m_width, m_height, w, h, min_size=192): 23 | if h > w: 24 | w = resolve_bucket(m_width, h, w) 25 | w = closest_bucket(m_width, w, 'down', min_size=min_size) 26 | return w, m_height 27 | if h < w: 28 | h = resolve_bucket(m_height, w, h) 29 | h = closest_bucket(m_height, h, 'down', min_size=min_size) 30 | return m_width, h 31 | 32 | return m_width, m_height -------------------------------------------------------------------------------- /utils/common.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | from PIL import Image 4 | import torch 5 | import random 6 | import numpy as np 7 | import torchvision.transforms as T 8 | from einops import rearrange, repeat 9 | import imageio 10 | import sys 11 | 12 | def tensor_to_vae_latent(t, vae): 13 | video_length = t.shape[1] 14 | 15 | t = rearrange(t, "b f c h w -> (b f) c h w") 16 | latents = vae.encode(t).latent_dist.mode() 17 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) 18 | latents = latents * 0.18215 19 | 20 | return latents 21 | 22 | def DDPM_forward(x0, step, num_frames, scheduler): 23 | device = x0.device 24 | t = scheduler.timesteps[-1] 25 | xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames) 26 | 27 | eps = torch.randn_like(xt) 28 | alpha_vec = torch.prod(scheduler.alphas[t:]) 29 | xt = torch.sqrt(alpha_vec) * xt + torch.sqrt(1-alpha_vec) * eps 30 | return xt, None 31 | 32 | def DDPM_forward_timesteps(x0, step, num_frames, scheduler): 33 | '''larger step -> smaller t -> smaller alphas[t:] -> smaller xt -> smaller x0''' 34 | 35 | device = x0.device 36 | # timesteps are reversed 37 | timesteps = scheduler.timesteps[len(scheduler.timesteps)-step:] 38 | t = timesteps[0] 39 | 40 | if x0.shape[2] == 1: 41 | xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames) 42 | else: 43 | xt = x0 44 | noise = torch.randn(xt.shape, dtype=xt.dtype, device=device) 45 | # t to tensor of batch size 46 | t = torch.tensor([t]*xt.shape[0], device=device) 47 | xt = scheduler.add_noise(xt, noise, t) 48 | return xt, timesteps 49 | 50 | def DDPM_forward_mask(x0, step, num_frames, scheduler, mask): 51 | '''larger step -> smaller t -> smaller alphas[t:] -> smaller xt -> smaller x0''' 52 | device = x0.device 53 | dtype = x0.dtype 54 | b, c, f, h, w = x0.shape 55 | 56 | move_xt, timesteps = DDPM_forward_timesteps(x0, step, num_frames, scheduler) 57 | mask = T.ToTensor()(mask).to(dtype).to(device) 58 | mask = T.Resize([h, w], antialias=False)(mask) 59 | mask = rearrange(mask, 'b h w -> b 1 1 h w') 60 | freeze_xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames) 61 | initial = freeze_xt * (1-mask) + move_xt * mask 62 | return initial, timesteps 63 | 64 | def read_video(video_path, frame_number=-1): 65 | # Open the video file 66 | cap = cv2.VideoCapture(video_path) 67 | count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 68 | if frame_number == -1: 69 | frame_number = count 70 | else: 71 | frame_number = min(frame_number, count) 72 | frames = [] 73 | for i in range(frame_number): 74 | ret, ref_frame = cap.read() 75 | ref_frame = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2RGB) 76 | if not ret: 77 | raise ValueError("Failed to read video file") 78 | frames.append(ref_frame) 79 | return frames 80 | 81 | def get_full_white_area_mask(frames): 82 | ref_frame = frames[0] 83 | ref_gray = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2GRAY) 84 | total_mask = np.ones_like(ref_gray) * 255 85 | 86 | return total_mask 87 | 88 | def get_moved_area_mask(frames, move_th=5, th=-1): 89 | ref_frame = frames[0] 90 | # Convert the reference frame to gray 91 | ref_gray = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2GRAY) 92 | prev_gray = ref_gray 93 | # Initialize the total accumulated motion mask 94 | total_mask = np.zeros_like(ref_gray) 95 | 96 | # Iterate through the video frames 97 | for i in range(1, len(frames)): 98 | frame = frames[i] 99 | # Convert the frame to gray 100 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 101 | 102 | # Compute the absolute 
difference between the reference frame and the current frame 103 | diff = cv2.absdiff(ref_gray, gray) 104 | #diff += cv2.absdiff(prev_gray, gray) 105 | 106 | # Apply a threshold to obtain a binary image 107 | ret, mask = cv2.threshold(diff, move_th, 255, cv2.THRESH_BINARY) 108 | 109 | # Accumulate the mask 110 | total_mask = cv2.bitwise_or(total_mask, mask) 111 | 112 | # Update the reference frame 113 | prev_gray = gray 114 | 115 | contours, _ = cv2.findContours(total_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 116 | rects = [] 117 | ref_mask = np.zeros_like(ref_gray) 118 | ref_mask = cv2.drawContours(ref_mask, contours, -1, (255, 255, 255), -1) 119 | for cnt in contours: 120 | cur_rec = cv2.boundingRect(cnt) 121 | rects.append(cur_rec) 122 | 123 | #rects = merge_overlapping_rectangles(rects) 124 | mask = np.zeros_like(ref_gray) 125 | if th < 0: 126 | h, w = mask.shape 127 | th = int(h*w*0.005) 128 | for rect in rects: 129 | x, y, w, h = rect 130 | if w*h < th: 131 | continue 132 | #ref_frame = cv2.rectangle(ref_frame, (x, y), (x+w, y+h), (0, 255, 0), 2) 133 | mask[y:y+h, x:x+w] = 255 134 | return mask 135 | 136 | def calculate_motion_precision(frames, mask): 137 | moved_mask = get_moved_area_mask(frames, move_th=20, th=0) 138 | moved = moved_mask == 255 139 | gt = mask == 255 140 | precision = np.sum(moved & gt) / np.sum(moved) 141 | return precision 142 | 143 | def check_overlap(rect1, rect2): 144 | # Calculate the coordinates of the edges of the rectangles 145 | rect1_left = rect1[0] 146 | rect1_right = rect1[0] + rect1[2] 147 | rect1_top = rect1[1] 148 | rect1_bottom = rect1[1] + rect1[3] 149 | 150 | rect2_left = rect2[0] 151 | rect2_right = rect2[0] + rect2[2] 152 | rect2_top = rect2[1] 153 | rect2_bottom = rect2[1] + rect2[3] 154 | 155 | # Check if the rectangles overlap 156 | if (rect2_left >= rect1_right or rect2_right <= rect1_left or 157 | rect2_top >= rect1_bottom or rect2_bottom <= rect1_top): 158 | return False 159 | else: 160 | return True 161 | 162 | def merge_rects(rect1, rect2): 163 | left = min(rect1[0], rect2[0]) 164 | top = min(rect1[1], rect2[1]) 165 | right = max(rect1[0]+rect1[2], rect2[0]+rect2[2]) 166 | bottom = max(rect1[1]+rect1[3], rect2[1]+rect2[3]) 167 | width = right - left 168 | height = bottom - top 169 | return (left, top, width, height) 170 | 171 | def merge_overlapping_rectangles(rectangles): 172 | # Sort the rectangles based on their left coordinate 173 | sorted_rectangles = sorted(rectangles, key=lambda x: x[0]) 174 | 175 | # Initialize an empty list to store the merged rectangles 176 | merged_rectangles = [] 177 | 178 | # Iterate through the sorted rectangles and merge them 179 | for rect in sorted_rectangles: 180 | if not merged_rectangles: 181 | # If the merged rectangles list is empty, add the first rectangle to it 182 | merged_rectangles.append(rect) 183 | else: 184 | # Get the last merged rectangle 185 | last_merged = merged_rectangles[-1] 186 | 187 | # Check if the current rectangle overlaps with the last merged rectangle 188 | if last_merged[0] + last_merged[2] >= rect[0]: 189 | # Merge the rectangles if they overlap 190 | merged_rectangles[-1] = ( 191 | min(last_merged[0], rect[0]), 192 | min(last_merged[1], rect[1]), 193 | max(last_merged[0] + last_merged[2], rect[0] + rect[2]) - min(last_merged[0], rect[0]), 194 | max(last_merged[1] + last_merged[3], rect[1] + rect[3]) - min(last_merged[1], rect[1]) 195 | ) 196 | else: 197 | # Add the current rectangle to the merged rectangles list if they don't overlap 198 | merged_rectangles.append(rect) 
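    # [Editor's note] Worked example of the merge rule above, with rects given as (x, y, w, h):
    #   [(0, 0, 4, 4), (2, 2, 4, 4), (10, 0, 2, 2)] -> [(0, 0, 6, 6), (10, 0, 2, 2)]
    # Note that the check last_merged[0] + last_merged[2] >= rect[0] only tests horizontal
    # adjacency, so rectangles that overlap in x but are disjoint in y are still merged.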
199 | 200 | return merged_rectangles 201 | 202 | def generate_random_mask(image): 203 | # Create a blank mask with the same size as the image 204 | b, c , h, w = image.shape 205 | mask = np.zeros([b, h, w], dtype=np.uint8) 206 | 207 | # Generate random coordinates for the mask 208 | num_points = np.random.randint(3, 10) # Randomly choose the number of points to generate 209 | points = np.random.randint(0, min(h, w), size=(num_points, 2)) # Randomly generate the points 210 | # Draw a filled polygon on the mask using the random points 211 | for i in range(b): 212 | width = random.randint(w//4, w) 213 | height = random.randint(h//4, h) 214 | x = random.randint(0, w-width) 215 | y = random.randint(0, h-height) 216 | points=np.array([[x, y], [x+width, y], [x+width, y+height], [x, y+height]]) 217 | mask[i] = cv2.fillPoly(mask[i], [points], 255) 218 | 219 | # Apply the mask to the image 220 | #masked_image = cv2.bitwise_and(image, image, mask=mask) 221 | return mask 222 | 223 | def generate_center_mask(image): 224 | # Create a blank mask with the same size as the image 225 | b, c , h, w = image.shape 226 | mask = np.zeros([b, h, w], dtype=np.uint8) 227 | 228 | # Generate random coordinates for the mask 229 | for i in range(b): 230 | width = int(w/10) 231 | height = int(h/10) 232 | mask[i][height:-height,width:-width] = 255 233 | # Apply the mask to the image 234 | #masked_image = cv2.bitwise_and(image, image, mask=mask) 235 | return mask 236 | 237 | def read_mask(json_path, label=["mask"]): 238 | j = json.load(open(json_path)) 239 | if type(label) != list: 240 | labels = [label] 241 | height = j['imageHeight'] 242 | width = j['imageWidth'] 243 | mask = np.zeros([height, width], dtype=np.uint8) 244 | for shape in j['shapes']: 245 | if shape['label'] in label: 246 | x1, y1 = shape['points'][0] 247 | x2, y2 = shape['points'][1] 248 | mask[int(y1):int(y2), int(x1):int(x2)] = 255 249 | return mask 250 | 251 | 252 | def slerp(z1, z2, alpha): 253 | theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))) 254 | return ( 255 | torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1 256 | + torch.sin(alpha * theta) / torch.sin(theta) * z2 257 | ) 258 | 259 | def _detect_edges(lum: np.ndarray, kernel_size=5) -> np.ndarray: 260 | """Detect edges using the luma channel of a frame. 261 | 262 | Arguments: 263 | lum: 2D 8-bit image representing the luma channel of a frame. 264 | 265 | Returns: 266 | 2D 8-bit image of the same size as the input, where pixels with values of 255 267 | represent edges, and all other pixels are 0. 268 | """ 269 | # Initialize kernel. 270 | #kernel_size = _estimated_kernel_size(lum.shape[1], lum.shape[0]) 271 | _kernel = np.ones((kernel_size, kernel_size), np.uint8) 272 | 273 | # Estimate levels for thresholding. 274 | # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc. 275 | sigma: float = 1.0 / 3.0 276 | median = np.median(lum) 277 | low = int(max(0, (1.0 - sigma) * median)) 278 | high = int(min(255, (1.0 + sigma) * median)) 279 | 280 | # Calculate edges using Canny algorithm, and reduce noise by dilating the edges. 281 | # This increases edge overlap leading to improved robustness against noise and slow 282 | # camera movement. Note that very large kernel sizes can negatively affect accuracy. 
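    # [Editor's note] Example of the sigma/median thresholding above: with sigma = 1/3 and a
    # median luma of 120, low = int((2/3) * 120) = 80 and high = int((4/3) * 120) = 160.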
283 | edges = cv2.Canny(lum, low, high) 284 | return cv2.dilate(edges, _kernel) 285 | 286 | 287 | def _mean_pixel_distance(left: np.ndarray, right: np.ndarray) -> float: 288 | """Return the mean average distance in pixel values between `left` and `right`. 289 | Both `left and `right` should be 2 dimensional 8-bit images of the same shape. 290 | """ 291 | assert len(left.shape) == 2 and len(right.shape) == 2 292 | assert left.shape == right.shape 293 | num_pixels: float = float(left.shape[0] * left.shape[1]) 294 | return (np.sum(np.abs(left.astype(np.int32) - right.astype(np.int32))) / num_pixels) 295 | 296 | def calculate_latent_motion_score(latents): 297 | #latents b, c f, h, w 298 | diff=torch.abs(latents[:,:,1:]-latents[:,:,:-1]) 299 | motion_score = torch.sum(torch.mean(diff, dim=[2,3,4]), dim=1) * 10 300 | return motion_score 301 | 302 | def motion_mask_loss(latents, mask): 303 | diff = torch.abs(latents[:,:,1:] - latents[:,:,:-1]) 304 | loss = torch.sum(torch.mean(diff * (1-mask), dim=[2,3,4]), dim=1) 305 | return loss 306 | 307 | def calculate_motion_score(frame_imgs, calculate_edges=False, color="RGB") -> float: 308 | # Convert image into HSV colorspace. 309 | _last_frame = None 310 | 311 | _weights = [1.0, 1.0, 1.0, 0.0] 312 | score = 0 313 | for frame_img in frame_imgs: 314 | if color == "RGB": 315 | hue, sat, lum = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_RGB2HSV)) 316 | else: 317 | hue, sat, lum = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2HSV)) 318 | # Performance: Only calculate edges if we have to. 319 | edges = _detect_edges(lum) if calculate_edges else None 320 | if _last_frame == None: 321 | _last_frame = (hue, sat, lum, edges) 322 | continue 323 | 324 | score_components = [ 325 | _mean_pixel_distance(hue, _last_frame[0]), 326 | _mean_pixel_distance(sat, _last_frame[1]), 327 | _mean_pixel_distance(lum, _last_frame[2]), 328 | 0.0 if edges is None else _mean_pixel_distance(edges, _last_frame[3]), 329 | ] 330 | 331 | frame_score: float = ( 332 | sum(component * weight for (component, weight) in zip(score_components, _weights)) 333 | / sum(abs(weight) for weight in _weights)) 334 | score += frame_score 335 | _last_frame = (hue, sat, lum, edges) 336 | 337 | return round(score/(len(frame_imgs)-1) * 10) 338 | 339 | if __name__ == "__main__": 340 | 341 | # Example usage 342 | video_paths = [ 343 | "/data/video/animate2/Bleach.Sennen.Kessen.Hen.S01E01.2022.1080p.WEB-DL.x264.AAC-DDHDTV-Scene-002.mp4", 344 | "/data/video/animate2/Evangelion.3.0.1.01.Thrice.Upon.A.Time.2021.BLURAY.720p.BluRay.x264.AAC-[YTS.MX]-Scene-0780.mp4", 345 | "/data/video/animate2/[GM-Team][国漫][永生 第2季][IMMORTALITY Ⅱ][2023][09][AVC][GB][1080P]-Scene-180.mp4", 346 | "/data/video/animate2/[orion origin] Legend of the Galactic Heroes Die Neue These [07] [WebRip 1080p] [H265 AAC] [GB]-Scene-048.mp4", 347 | "/data/video/MSRVTT/videos/all/video33.mp4", 348 | "/webvid/webvid/data/videos/000001_000050/1066692580.mp4", 349 | "/webvid/webvid/data/videos/000001_000050/1066685533.mp4", 350 | "/webvid/webvid/data/videos/000001_000050/1066685548.mp4", 351 | "/webvid/webvid/data/videos/000001_000050/1066676380.mp4", 352 | "/webvid/webvid/data/videos/000001_000050/1066676377.mp4", 353 | ] 354 | for i, video_path in enumerate(video_paths[:5]): 355 | frames = read_video(video_path, 200)[::3] 356 | if sys.argv[1] == 'test_mask': 357 | mask = get_moved_area_mask(frames) 358 | Image.fromarray(mask).save(f"output/mask/{i}.jpg") 359 | imageio.mimwrite(f"output/mask/{i}.gif", frames, duration=125, loop=0) 360 | elif 
sys.argv[1] == 'test_motion': 361 | for r in range(0, len(frames), 16): 362 | video_frames = frames[r:r+16] 363 | video_frames = [cv2.resize(f, (512, 512)) for f in video_frames] 364 | score = calculate_motion_score(video_frames, calculate_edges=False, color="BGR") 365 | imageio.mimwrite(f"output/example_video/{i}_{r}_{score}.mp4", video_frames, fps=8) 366 | elif sys.argv[1] == 'to_gif': 367 | imageio.mimwrite(f"output/example_video/{i}.gif", frames, duration=125, loop=0) 368 | -------------------------------------------------------------------------------- /utils/convert_diffusers_to_original_ms_text_to_video.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | 5 | import argparse 6 | import os.path as osp 7 | import re 8 | 9 | import torch 10 | from safetensors.torch import load_file, save_file 11 | 12 | # =================# 13 | # UNet Conversion # 14 | # =================# 15 | 16 | print ('Initializing the conversion map') 17 | 18 | unet_conversion_map = [ 19 | # (ModelScope, HF Diffusers) 20 | 21 | # from Vanilla ModelScope/StableDiffusion 22 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 23 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 24 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 25 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 26 | 27 | 28 | # from Vanilla ModelScope/StableDiffusion 29 | ("input_blocks.0.0.weight", "conv_in.weight"), 30 | ("input_blocks.0.0.bias", "conv_in.bias"), 31 | 32 | 33 | # from Vanilla ModelScope/StableDiffusion 34 | ("out.0.weight", "conv_norm_out.weight"), 35 | ("out.0.bias", "conv_norm_out.bias"), 36 | ("out.2.weight", "conv_out.weight"), 37 | ("out.2.bias", "conv_out.bias"), 38 | ] 39 | 40 | unet_conversion_map_resnet = [ 41 | # (ModelScope, HF Diffusers) 42 | 43 | # SD 44 | ("in_layers.0", "norm1"), 45 | ("in_layers.2", "conv1"), 46 | ("out_layers.0", "norm2"), 47 | ("out_layers.3", "conv2"), 48 | ("emb_layers.1", "time_emb_proj"), 49 | ("skip_connection", "conv_shortcut"), 50 | 51 | # MS 52 | #("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha 53 | ] 54 | 55 | unet_conversion_map_layer = [] 56 | 57 | # Convert input TemporalTransformer 58 | unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in')) 59 | 60 | # Reference for the default settings 61 | 62 | # "model_cfg": { 63 | # "unet_in_dim": 4, 64 | # "unet_dim": 320, 65 | # "unet_y_dim": 768, 66 | # "unet_context_dim": 1024, 67 | # "unet_out_dim": 4, 68 | # "unet_dim_mult": [1, 2, 4, 4], 69 | # "unet_num_heads": 8, 70 | # "unet_head_dim": 64, 71 | # "unet_res_blocks": 2, 72 | # "unet_attn_scales": [1, 0.5, 0.25], 73 | # "unet_dropout": 0.1, 74 | # "temporal_attention": "True", 75 | # "num_timesteps": 1000, 76 | # "mean_type": "eps", 77 | # "var_type": "fixed_small", 78 | # "loss_type": "mse" 79 | # } 80 | 81 | # hardcoded number of downblocks and resnets/attentions... 82 | # would need smarter logic for other networks. 83 | for i in range(4): 84 | # loop over downblocks/upblocks 85 | 86 | for j in range(2): 87 | # loop over resnets/attentions for downblocks 88 | 89 | # Spacial SD stuff 90 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 91 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 
92 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 93 | 94 | if i < 3: 95 | # no attention layers in down_blocks.3 96 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 97 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 98 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 99 | 100 | # Temporal MS stuff 101 | hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}." 102 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv." 103 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 104 | 105 | if i < 3: 106 | # no attention layers in down_blocks.3 107 | hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}." 108 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2." 109 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 110 | 111 | for j in range(3): 112 | # loop over resnets/attentions for upblocks 113 | 114 | # Spacial SD stuff 115 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 116 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 117 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 118 | 119 | if i > 0: 120 | # no attention layers in up_blocks.0 121 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 122 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 123 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 124 | 125 | # loop over resnets/attentions for upblocks 126 | hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}." 127 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv." 128 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 129 | 130 | if i > 0: 131 | # no attention layers in up_blocks.0 132 | hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}." 133 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.2." 134 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 135 | 136 | # Up/Downsamplers are 2D, so don't need to touch them 137 | if i < 3: 138 | # no downsample in down_blocks.3 139 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 140 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op." 141 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 142 | 143 | # no upsample in up_blocks.3 144 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 145 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}." 146 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 147 | 148 | 149 | # Handle the middle block 150 | 151 | # Spacial 152 | hf_mid_atn_prefix = "mid_block.attentions.0." 153 | sd_mid_atn_prefix = "middle_block.1." 154 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 155 | 156 | for j in range(2): 157 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 158 | sd_mid_res_prefix = f"middle_block.{3*j}." 159 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 160 | 161 | # Temporal 162 | hf_mid_atn_prefix = "mid_block.temp_attentions.0." 163 | sd_mid_atn_prefix = "middle_block.2." 164 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 165 | 166 | for j in range(2): 167 | hf_mid_res_prefix = f"mid_block.temp_convs.{j}." 168 | sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv." 
169 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 170 | 171 | # The pipeline 172 | def convert_unet_state_dict(unet_state_dict, strict_mapping=False): 173 | print ('Converting the UNET') 174 | # buyer beware: this is a *brittle* function, 175 | # and correct output requires that all of these pieces interact in 176 | # the exact order in which I have arranged them. 177 | mapping = {k: k for k in unet_state_dict.keys()} 178 | 179 | for sd_name, hf_name in unet_conversion_map: 180 | if strict_mapping: 181 | if hf_name in mapping: 182 | mapping[hf_name] = sd_name 183 | else: 184 | mapping[hf_name] = sd_name 185 | for k, v in mapping.items(): 186 | if "resnets" in k: 187 | for sd_part, hf_part in unet_conversion_map_resnet: 188 | v = v.replace(hf_part, sd_part) 189 | mapping[k] = v 190 | # elif "temp_convs" in k: 191 | # for sd_part, hf_part in unet_conversion_map_resnet: 192 | # v = v.replace(hf_part, sd_part) 193 | # mapping[k] = v 194 | for k, v in mapping.items(): 195 | for sd_part, hf_part in unet_conversion_map_layer: 196 | v = v.replace(hf_part, sd_part) 197 | mapping[k] = v 198 | 199 | 200 | # there must be a pattern, but I don't want to bother atm 201 | do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]] 202 | print (do_not_unsqueeze) 203 | 204 | new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()} 205 | # HACK: idk why the hell it does not work with list comprehension 206 | for k, v in new_state_dict.items(): 207 | has_k = False 208 | for n in do_not_unsqueeze: 209 | if k == n: 210 | has_k = True 211 | 212 | if has_k: 213 | v = v.squeeze(-1) 214 | new_state_dict[k] = v 215 | 216 | return new_state_dict 217 | 218 | # TODO: VAE conversion. 
We doesn't train it in the most cases, but may be handy for the future --kabachuha 219 | 220 | # =========================# 221 | # Text Encoder Conversion # 222 | # =========================# 223 | 224 | # IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha 225 | 226 | # =========================# 227 | # Text Encoder Conversion # 228 | # =========================# 229 | 230 | 231 | textenc_conversion_lst = [ 232 | # (stable-diffusion, HF Diffusers) 233 | ("resblocks.", "text_model.encoder.layers."), 234 | ("ln_1", "layer_norm1"), 235 | ("ln_2", "layer_norm2"), 236 | (".c_fc.", ".fc1."), 237 | (".c_proj.", ".fc2."), 238 | (".attn", ".self_attn"), 239 | ("ln_final.", "transformer.text_model.final_layer_norm."), 240 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), 241 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), 242 | ] 243 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} 244 | textenc_pattern = re.compile("|".join(protected.keys())) 245 | 246 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp 247 | code2idx = {"q": 0, "k": 1, "v": 2} 248 | 249 | 250 | def convert_text_enc_state_dict_v20(text_enc_dict): 251 | #print ('Converting the text encoder') 252 | new_state_dict = {} 253 | capture_qkv_weight = {} 254 | capture_qkv_bias = {} 255 | for k, v in text_enc_dict.items(): 256 | if ( 257 | k.endswith(".self_attn.q_proj.weight") 258 | or k.endswith(".self_attn.k_proj.weight") 259 | or k.endswith(".self_attn.v_proj.weight") 260 | ): 261 | k_pre = k[: -len(".q_proj.weight")] 262 | k_code = k[-len("q_proj.weight")] 263 | if k_pre not in capture_qkv_weight: 264 | capture_qkv_weight[k_pre] = [None, None, None] 265 | capture_qkv_weight[k_pre][code2idx[k_code]] = v 266 | continue 267 | 268 | if ( 269 | k.endswith(".self_attn.q_proj.bias") 270 | or k.endswith(".self_attn.k_proj.bias") 271 | or k.endswith(".self_attn.v_proj.bias") 272 | ): 273 | k_pre = k[: -len(".q_proj.bias")] 274 | k_code = k[-len("q_proj.bias")] 275 | if k_pre not in capture_qkv_bias: 276 | capture_qkv_bias[k_pre] = [None, None, None] 277 | capture_qkv_bias[k_pre][code2idx[k_code]] = v 278 | continue 279 | 280 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) 281 | new_state_dict[relabelled_key] = v 282 | 283 | for k_pre, tensors in capture_qkv_weight.items(): 284 | if None in tensors: 285 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 286 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 287 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) 288 | 289 | for k_pre, tensors in capture_qkv_bias.items(): 290 | if None in tensors: 291 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 292 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 293 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) 294 | 295 | return new_state_dict 296 | 297 | 298 | def convert_text_enc_state_dict(text_enc_dict): 299 | return text_enc_dict 300 | 301 | textenc_conversion_lst = [ 302 | # (stable-diffusion, HF Diffusers) 303 | ("resblocks.", "text_model.encoder.layers."), 304 | ("ln_1", "layer_norm1"), 305 | ("ln_2", "layer_norm2"), 306 | (".c_fc.", ".fc1."), 307 | (".c_proj.", ".fc2."), 308 | (".attn", ".self_attn"), 309 | ("ln_final.", 
"transformer.text_model.final_layer_norm."), 310 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), 311 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), 312 | ] 313 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} 314 | textenc_pattern = re.compile("|".join(protected.keys())) 315 | 316 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp 317 | code2idx = {"q": 0, "k": 1, "v": 2} 318 | 319 | 320 | def convert_text_enc_state_dict_v20(text_enc_dict): 321 | new_state_dict = {} 322 | capture_qkv_weight = {} 323 | capture_qkv_bias = {} 324 | for k, v in text_enc_dict.items(): 325 | if ( 326 | k.endswith(".self_attn.q_proj.weight") 327 | or k.endswith(".self_attn.k_proj.weight") 328 | or k.endswith(".self_attn.v_proj.weight") 329 | ): 330 | k_pre = k[: -len(".q_proj.weight")] 331 | k_code = k[-len("q_proj.weight")] 332 | if k_pre not in capture_qkv_weight: 333 | capture_qkv_weight[k_pre] = [None, None, None] 334 | capture_qkv_weight[k_pre][code2idx[k_code]] = v 335 | continue 336 | 337 | if ( 338 | k.endswith(".self_attn.q_proj.bias") 339 | or k.endswith(".self_attn.k_proj.bias") 340 | or k.endswith(".self_attn.v_proj.bias") 341 | ): 342 | k_pre = k[: -len(".q_proj.bias")] 343 | k_code = k[-len("q_proj.bias")] 344 | if k_pre not in capture_qkv_bias: 345 | capture_qkv_bias[k_pre] = [None, None, None] 346 | capture_qkv_bias[k_pre][code2idx[k_code]] = v 347 | continue 348 | 349 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) 350 | new_state_dict[relabelled_key] = v 351 | 352 | for k_pre, tensors in capture_qkv_weight.items(): 353 | if None in tensors: 354 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 355 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 356 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) 357 | 358 | for k_pre, tensors in capture_qkv_bias.items(): 359 | if None in tensors: 360 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 361 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 362 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) 363 | 364 | return new_state_dict 365 | 366 | 367 | def convert_text_enc_state_dict(text_enc_dict): 368 | return text_enc_dict 369 | 370 | if __name__ == "__main__": 371 | parser = argparse.ArgumentParser() 372 | 373 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 374 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 375 | parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.") 376 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 377 | parser.add_argument( 378 | "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." 379 | ) 380 | 381 | args = parser.parse_args() 382 | 383 | assert args.model_path is not None, "Must provide a model path!" 384 | 385 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 386 | 387 | assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!" 
388 | 389 | # Path for safetensors 390 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") 391 | #vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") 392 | text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") 393 | 394 | # Load models from safetensors if it exists, if it doesn't pytorch 395 | if osp.exists(unet_path): 396 | unet_state_dict = load_file(unet_path, device="cpu") 397 | else: 398 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 399 | unet_state_dict = torch.load(unet_path, map_location="cpu") 400 | 401 | # if osp.exists(vae_path): 402 | # vae_state_dict = load_file(vae_path, device="cpu") 403 | # else: 404 | # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 405 | # vae_state_dict = torch.load(vae_path, map_location="cpu") 406 | 407 | if osp.exists(text_enc_path): 408 | text_enc_dict = load_file(text_enc_path, device="cpu") 409 | else: 410 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 411 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") 412 | 413 | # Convert the UNet model 414 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 415 | #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 416 | 417 | # Convert the VAE model 418 | # vae_state_dict = convert_vae_state_dict(vae_state_dict) 419 | # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 420 | 421 | # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper 422 | is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict 423 | 424 | if is_v20_model: 425 | 426 | # MODELSCOPE always uses the 2.X encoder, btw --kabachuha 427 | 428 | # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm 429 | text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} 430 | text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) 431 | #text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} 432 | else: 433 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 434 | #text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 435 | 436 | # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha 437 | # Save CLIP and the Diffusion model to their own files 438 | 439 | #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 440 | print ('Saving UNET') 441 | state_dict = {**unet_state_dict} 442 | 443 | if args.half: 444 | state_dict = {k: v.half() for k, v in state_dict.items()} 445 | 446 | if args.use_safetensors: 447 | save_file(state_dict, args.checkpoint_path) 448 | else: 449 | #state_dict = {"state_dict": state_dict} 450 | torch.save(state_dict, args.checkpoint_path) 451 | 452 | # TODO: CLIP conversion doesn't work atm 453 | # print ('Saving CLIP') 454 | # state_dict = {**text_enc_dict} 455 | 456 | # if args.half: 457 | # state_dict = {k: v.half() for k, v in state_dict.items()} 458 | 459 | # if args.use_safetensors: 460 | # save_file(state_dict, args.checkpoint_path) 461 | # else: 462 | # #state_dict = {"state_dict": state_dict} 463 | # torch.save(state_dict, args.clip_checkpoint_path) 464 | 465 | print('Operation successfull') 466 | -------------------------------------------------------------------------------- /utils/lama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on the implementation from: 3 | https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/tree/main 4 | 5 | Modules were adapted by Hans Brouwer to only support the final configuration of the model uploaded here: 6 | https://huggingface.co/akhaliq/lama 7 | 8 | Apache License 2.0: https://github.com/advimman/lama/blob/main/LICENSE 9 | 10 | @article{suvorov2021resolution, 11 | title={Resolution-robust Large Mask Inpainting with Fourier Convolutions}, 12 | author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor}, 13 | journal={arXiv preprint arXiv:2109.07161}, 14 | year={2021} 15 | } 16 | """ 17 | 18 | import os 19 | import sys 20 | from urllib.request import urlretrieve 21 | 22 | import torch 23 | from einops import rearrange 24 | from PIL import Image 25 | from torch import nn 26 | from torch.nn import functional as F 27 | from torchvision.transforms.functional import to_tensor 28 | from tqdm import tqdm 29 | 30 | from train import export_to_video 31 | 32 | 33 | LAMA_URL = "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt" 34 | LAMA_PATH = "models/lama.ckpt" 35 | 36 | 37 | def download_progress(t): 38 | last_b = [0] 39 | 40 | def update_to(b=1, bsize=1, tsize=None): 41 | if tsize is not None: 42 | t.total = tsize 43 | t.update((b - last_b[0]) * bsize) 44 | last_b[0] = b 45 | 46 | return update_to 47 | 48 | 49 | def download(url, path): 50 | with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=path) as t: 51 | urlretrieve(url, filename=path, reporthook=download_progress(t), data=None) 52 | 53 | 54 | class FourierUnit(nn.Module): 55 | def __init__(self, in_channels, out_channels, groups=1): 56 | super(FourierUnit, self).__init__() 57 | self.groups = groups 58 | self.conv_layer = torch.nn.Conv2d( 59 | in_channels=in_channels * 2, 60 | out_channels=out_channels * 2, 61 | kernel_size=1, 62 | stride=1, 63 | padding=0, 64 | groups=self.groups, 65 | bias=False, 66 | ) 67 | self.bn = torch.nn.BatchNorm2d(out_channels * 2) 68 | self.relu = torch.nn.ReLU(inplace=True) 69 | 70 | def forward(self, x): 71 | 
batch = x.shape[0] 72 | 73 | # (batch, c, h, w/2+1, 2) 74 | fft_dim = (-2, -1) 75 | ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho") 76 | ffted = torch.stack((ffted.real, ffted.imag), dim=-1) 77 | ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) 78 | ffted = ffted.view((batch, -1) + ffted.size()[3:]) 79 | 80 | ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) 81 | ffted = self.relu(self.bn(ffted)) 82 | 83 | # (batch,c, t, h, w/2+1, 2) 84 | ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous() 85 | ffted = torch.complex(ffted[..., 0], ffted[..., 1]) 86 | 87 | ifft_shape_slice = x.shape[-2:] 88 | output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm="ortho") 89 | 90 | return output 91 | 92 | 93 | class SpectralTransform(nn.Module): 94 | def __init__(self, in_channels, out_channels, stride=1, groups=1): 95 | super(SpectralTransform, self).__init__() 96 | self.stride = stride 97 | if stride == 2: 98 | self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) 99 | else: 100 | self.downsample = nn.Identity() 101 | 102 | self.conv1 = nn.Sequential( 103 | nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False), 104 | nn.BatchNorm2d(out_channels // 2), 105 | nn.ReLU(inplace=True), 106 | ) 107 | self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups) 108 | self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) 109 | 110 | def forward(self, x): 111 | x = self.downsample(x) 112 | x = self.conv1(x) 113 | output = self.fu(x) 114 | output = self.conv2(x + output) 115 | return output 116 | 117 | 118 | class FFC(nn.Module): 119 | def __init__( 120 | self, 121 | in_channels, 122 | out_channels, 123 | kernel_size, 124 | ratio_gin, 125 | ratio_gout, 126 | stride=1, 127 | padding=0, 128 | dilation=1, 129 | groups=1, 130 | bias=False, 131 | padding_type="reflect", 132 | gated=False, 133 | ): 134 | super(FFC, self).__init__() 135 | 136 | assert stride == 1 or stride == 2, "Stride should be 1 or 2." 
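# --- Illustrative sketch added for clarity; not part of the original lama.py. ---
# The FourierUnit defined above is what gives LaMa a global receptive field in a
# single layer: features go through a real FFT, a 1x1 convolution over the stacked
# real/imaginary channels, and an inverse FFT, so the spatial size is preserved.
# A quick shape check, assuming the classes above are in scope:
fu = FourierUnit(in_channels=64, out_channels=64)
x = torch.randn(2, 64, 32, 32)
assert fu(x).shape == x.shape  # (2, 64, 32, 32) in and out
# --- End of sketch. ---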
137 | self.stride = stride 138 | 139 | in_cg = int(in_channels * ratio_gin) 140 | in_cl = in_channels - in_cg 141 | out_cg = int(out_channels * ratio_gout) 142 | out_cl = out_channels - out_cg 143 | 144 | self.ratio_gin = ratio_gin 145 | self.ratio_gout = ratio_gout 146 | self.global_in_num = in_cg 147 | 148 | module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d 149 | self.convl2l = module( 150 | in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type 151 | ) 152 | module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d 153 | self.convl2g = module( 154 | in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type 155 | ) 156 | module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d 157 | self.convg2l = module( 158 | in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type 159 | ) 160 | module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform 161 | self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2) 162 | 163 | self.gated = gated 164 | module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d 165 | self.gate = module(in_channels, 2, 1) 166 | 167 | def forward(self, x): 168 | x_l, x_g = x if type(x) is tuple else (x, 0) 169 | out_xl, out_xg = 0, 0 170 | 171 | if self.gated: 172 | total_input_parts = [x_l] 173 | if torch.is_tensor(x_g): 174 | total_input_parts.append(x_g) 175 | total_input = torch.cat(total_input_parts, dim=1) 176 | 177 | gates = torch.sigmoid(self.gate(total_input)) 178 | g2l_gate, l2g_gate = gates.chunk(2, dim=1) 179 | else: 180 | g2l_gate, l2g_gate = 1, 1 181 | 182 | if self.ratio_gout != 1: 183 | out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate 184 | if self.ratio_gout != 0: 185 | out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) 186 | 187 | return out_xl, out_xg 188 | 189 | 190 | class FFC_BN_ACT(nn.Module): 191 | def __init__( 192 | self, 193 | in_channels, 194 | out_channels, 195 | kernel_size, 196 | ratio_gin=0, 197 | ratio_gout=0, 198 | stride=1, 199 | padding=0, 200 | dilation=1, 201 | groups=1, 202 | bias=False, 203 | norm_layer=nn.BatchNorm2d, 204 | activation_layer=nn.ReLU, 205 | ): 206 | super(FFC_BN_ACT, self).__init__() 207 | self.ffc = FFC( 208 | in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias 209 | ) 210 | lnorm = nn.Identity if ratio_gout == 1 else norm_layer 211 | gnorm = nn.Identity if ratio_gout == 0 else norm_layer 212 | global_channels = int(out_channels * ratio_gout) 213 | self.bn_l = lnorm(out_channels - global_channels) 214 | self.bn_g = gnorm(global_channels) 215 | 216 | lact = nn.Identity if ratio_gout == 1 else activation_layer 217 | gact = nn.Identity if ratio_gout == 0 else activation_layer 218 | self.act_l = lact(inplace=True) 219 | self.act_g = gact(inplace=True) 220 | 221 | def forward(self, x): 222 | x_l, x_g = self.ffc(x) 223 | x_l = self.act_l(self.bn_l(x_l)) 224 | x_g = self.act_g(self.bn_g(x_g)) 225 | return x_l, x_g 226 | 227 | 228 | class FFCResnetBlock(nn.Module): 229 | def __init__(self, dim, ratio_gin, ratio_gout): 230 | super().__init__() 231 | self.conv1 = FFC_BN_ACT( 232 | dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout 233 | ) 234 | self.conv2 = FFC_BN_ACT( 235 | dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout 236 | ) 237 | 238 | def forward(self, x): 239 | x_l, x_g 
= x if type(x) is tuple else (x, 0) 240 | id_l, id_g = x_l, x_g 241 | x_l, x_g = self.conv1((x_l, x_g)) 242 | x_l, x_g = self.conv2((x_l, x_g)) 243 | x_l, x_g = id_l + x_l, id_g + x_g 244 | out = x_l, x_g 245 | return out 246 | 247 | 248 | class ConcatTupleLayer(nn.Module): 249 | def forward(self, x): 250 | assert isinstance(x, tuple) 251 | x_l, x_g = x 252 | assert torch.is_tensor(x_l) or torch.is_tensor(x_g) 253 | if not torch.is_tensor(x_g): 254 | return x_l 255 | return torch.cat(x, dim=1) 256 | 257 | 258 | class LargeMaskInpainting(nn.Module): 259 | def __init__(self, input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024): 260 | super().__init__() 261 | 262 | model = [nn.ReflectionPad2d(3), FFC_BN_ACT(input_nc, ngf, kernel_size=7)] 263 | 264 | ### downsample 265 | for i in range(n_downsampling): 266 | mult = 2**i 267 | model += [ 268 | FFC_BN_ACT( 269 | min(max_features, ngf * mult), 270 | min(max_features, ngf * mult * 2), 271 | kernel_size=3, 272 | stride=2, 273 | padding=1, 274 | ratio_gout=0.75 if i == n_downsampling - 1 else 0, 275 | ) 276 | ] 277 | 278 | ### resnet blocks 279 | for i in range(n_blocks): 280 | cur_resblock = FFCResnetBlock(min(max_features, ngf * 2**n_downsampling), ratio_gin=0.75, ratio_gout=0.75) 281 | model += [cur_resblock] 282 | 283 | model += [ConcatTupleLayer()] 284 | 285 | ### upsample 286 | for i in range(n_downsampling): 287 | mult = 2 ** (n_downsampling - i) 288 | model += [ 289 | nn.ConvTranspose2d( 290 | min(max_features, ngf * mult), 291 | min(max_features, int(ngf * mult / 2)), 292 | kernel_size=3, 293 | stride=2, 294 | padding=1, 295 | output_padding=1, 296 | ), 297 | nn.BatchNorm2d(min(max_features, int(ngf * mult / 2))), 298 | nn.ReLU(True), 299 | ] 300 | 301 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7), nn.Sigmoid()] 302 | self.model = nn.Sequential(*model) 303 | 304 | def forward(self, img, mask): 305 | masked_img = img * (1 - mask) 306 | masked_img = torch.cat([masked_img, mask], dim=1) 307 | pred = self.model(masked_img) 308 | inpainted = mask * pred + (1 - mask) * img 309 | return inpainted 310 | 311 | 312 | @torch.inference_mode() 313 | def inpaint_watermark(imgs): 314 | if not os.path.exists(LAMA_PATH): 315 | download(LAMA_URL, LAMA_PATH) 316 | 317 | mask = to_tensor(Image.open("./utils/mask.png").convert("L")).unsqueeze(0).to(imgs.device) 318 | if mask.shape[-1] != imgs.shape[-1]: 319 | mask = F.interpolate(mask, size=(imgs.shape[2], imgs.shape[3]), mode="nearest") 320 | mask = mask.expand(imgs.shape[0], 1, mask.shape[2], mask.shape[3]) 321 | 322 | model = LargeMaskInpainting().to(imgs.device) 323 | state_dict = torch.load(LAMA_PATH, map_location=imgs.device)["state_dict"] 324 | g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")} 325 | model.load_state_dict(g_dict) 326 | 327 | inpainted = model.forward(imgs, mask) 328 | 329 | return inpainted 330 | 331 | 332 | if __name__ == "__main__": 333 | import decord 334 | 335 | decord.bridge.set_bridge("torch") 336 | 337 | if len(sys.argv) < 2: 338 | print("Usage: python -m utils.lama ") 339 | sys.exit(1) 340 | 341 | video_path = sys.argv[1] 342 | out_path = video_path.replace(".mp4", " inpainted.mp4") 343 | 344 | vr = decord.VideoReader(video_path) 345 | fps = vr.get_avg_fps() 346 | video = rearrange(vr[:], "f h w c -> f c h w").div(255) 347 | 348 | inpainted = inpaint_watermark(video) 349 | inpainted = rearrange(inpainted, "f c h w -> f h w c").clamp(0, 1).mul(255).byte().cpu().numpy() 
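# --- Illustrative sketch added for clarity; not part of the original lama.py. ---
# inpaint_watermark() above accepts any float frame tensor in [0, 1] and runs the
# LaMa generator on whatever device the frames live on, so a single RGB image can
# be cleaned the same way as a clip. "frame.jpg" is a hypothetical input path.
frame = to_tensor(Image.open("frame.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, h, w)
clean = inpaint_watermark(frame.cuda() if torch.cuda.is_available() else frame)
# --- End of sketch. ---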
350 | export_to_video(inpainted, out_path, fps) 351 | -------------------------------------------------------------------------------- /utils/lora_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import torch 4 | from typing import Union 5 | from types import SimpleNamespace 6 | from models.unet_3d_condition_mask import UNet3DConditionModel 7 | from transformers import CLIPTextModel 8 | from utils.convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20 9 | 10 | from .lora import ( 11 | extract_lora_ups_down, 12 | inject_trainable_lora_extended, 13 | save_lora_weight, 14 | train_patch_pipe, 15 | monkeypatch_or_replace_lora, 16 | monkeypatch_or_replace_lora_extended 17 | ) 18 | 19 | from stable_lora.lora import ( 20 | activate_lora_train, 21 | add_lora_to, 22 | save_lora, 23 | load_lora, 24 | set_mode_group 25 | ) 26 | 27 | FILE_BASENAMES = ['unet', 'text_encoder'] 28 | LORA_FILE_TYPES = ['.pt', '.safetensors'] 29 | CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r'] 30 | STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias'] 31 | 32 | lora_versions = dict( 33 | stable_lora = "stable_lora", 34 | cloneofsimo = "cloneofsimo" 35 | ) 36 | 37 | lora_func_types = dict( 38 | loader = "loader", 39 | injector = "injector" 40 | ) 41 | 42 | lora_args = dict( 43 | model = None, 44 | loras = None, 45 | target_replace_module = [], 46 | target_module = [], 47 | r = 4, 48 | search_class = [torch.nn.Linear], 49 | dropout = 0, 50 | lora_bias = 'none' 51 | ) 52 | 53 | LoraVersions = SimpleNamespace(**lora_versions) 54 | LoraFuncTypes = SimpleNamespace(**lora_func_types) 55 | 56 | LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo] 57 | LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector] 58 | 59 | def filter_dict(_dict, keys=[]): 60 | if len(keys) == 0: 61 | raise ValueError("Keys cannot be empty when filtering the return dict.")
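# --- Illustrative sketch added for clarity; not part of the original lora_handler.py. ---
# filter_dict() trims the shared lora_args template down to the keys a given LoRA
# backend understands. For the cloneofsimo loader, for example:
#   filter_dict(lora_args.copy(), keys=CLONE_OF_SIMO_KEYS)
#   # -> {'model': None, 'loras': None, 'target_replace_module': [], 'r': 4}
# --- End of sketch. ---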
62 | 63 | for k in keys: 64 | if k not in lora_args.keys(): 65 | raise ValueError(f"{k} does not exist in the available LoRA arguments.") 66 | 67 | return {k: v for k, v in _dict.items() if k in keys} 68 | 69 | class LoraHandler(object): 70 | def __init__( 71 | self, 72 | version: LORA_VERSIONS = LoraVersions.cloneofsimo, 73 | use_unet_lora: bool = False, 74 | use_text_lora: bool = False, 75 | save_for_webui: bool = False, 76 | only_for_webui: bool = False, 77 | lora_bias: str = 'none', 78 | unet_replace_modules: list = ['UNet3DConditionModel'], 79 | text_encoder_replace_modules: list = ['CLIPEncoderLayer'] 80 | ): 81 | self.version = version 82 | self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader) 83 | self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector) 84 | self.lora_bias = lora_bias 85 | self.use_unet_lora = use_unet_lora 86 | self.use_text_lora = use_text_lora 87 | self.save_for_webui = save_for_webui 88 | self.only_for_webui = only_for_webui 89 | self.unet_replace_modules = unet_replace_modules 90 | self.text_encoder_replace_modules = text_encoder_replace_modules 91 | self.use_lora = any([use_text_lora, use_unet_lora]) 92 | 93 | if self.use_lora: 94 | print(f"Using LoRA Version: {self.version}") 95 | 96 | def is_cloneofsimo_lora(self): 97 | return self.version == LoraVersions.cloneofsimo 98 | 99 | def is_stable_lora(self): 100 | return self.version == LoraVersions.stable_lora 101 | 102 | def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader): 103 | 104 | if self.is_cloneofsimo_lora(): 105 | 106 | if func_type == LoraFuncTypes.loader: 107 | return monkeypatch_or_replace_lora_extended 108 | 109 | if func_type == LoraFuncTypes.injector: 110 | return inject_trainable_lora_extended 111 | 112 | if self.is_stable_lora(): 113 | 114 | if func_type == LoraFuncTypes.loader: 115 | return load_lora 116 | 117 | if func_type == LoraFuncTypes.injector: 118 | return add_lora_to 119 | 120 | raise ValueError("LoRA version does not exist.") 121 | 122 | def check_lora_ext(self, lora_file: str): 123 | return lora_file.endswith(tuple(LORA_FILE_TYPES)) 124 | 125 | def get_lora_file_path( 126 | self, 127 | lora_path: str, 128 | model: Union[UNet3DConditionModel, CLIPTextModel] 129 | ): 130 | if os.path.exists(lora_path): 131 | lora_filenames = [fns for fns in os.listdir(lora_path)] 132 | is_lora = self.check_lora_ext(lora_path) 133 | 134 | is_unet = isinstance(model, UNet3DConditionModel) 135 | is_text = isinstance(model, CLIPTextModel) 136 | idx = 0 if is_unet else 1 137 | 138 | base_name = FILE_BASENAMES[idx] 139 | 140 | for lora_filename in lora_filenames: 141 | is_lora = self.check_lora_ext(lora_filename) 142 | if not is_lora: 143 | continue 144 | 145 | if base_name in lora_filename: 146 | return os.path.join(lora_path, lora_filename) 147 | 148 | return None 149 | 150 | def handle_lora_load(self, file_name:str, lora_loader_args: dict = None): 151 | self.lora_loader(**lora_loader_args) 152 | print(f"Successfully loaded LoRA from: {file_name}") 153 | 154 | def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,): 155 | try: 156 | lora_file = self.get_lora_file_path(lora_path, model) 157 | 158 | if lora_file is not None: 159 | lora_loader_args.update({"lora_path": lora_file}) 160 | self.handle_lora_load(lora_file, lora_loader_args) 161 | 162 | else: 163 | print(f"Could not load LoRAs for {model.__class__.__name__}.
Injecting new ones instead...") 164 | 165 | except Exception as e: 166 | print(f"An error occured while loading a LoRA file: {e}") 167 | 168 | def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias): 169 | return_dict = lora_args.copy() 170 | 171 | if self.is_cloneofsimo_lora(): 172 | return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS) 173 | return_dict.update({ 174 | "model": model, 175 | "loras": self.get_lora_file_path(lora_path, model), 176 | "target_replace_module": replace_modules, 177 | "r": r 178 | }) 179 | 180 | if self.is_stable_lora(): 181 | KEYS = ['model', 'lora_path'] 182 | return_dict = filter_dict(return_dict, KEYS) 183 | 184 | return_dict.update({'model': model, 'lora_path': lora_path}) 185 | 186 | return return_dict 187 | 188 | def do_lora_injection( 189 | self, 190 | model, 191 | replace_modules, 192 | bias='none', 193 | dropout=0, 194 | r=4, 195 | lora_loader_args=None, 196 | ): 197 | REPLACE_MODULES = replace_modules 198 | 199 | params = None 200 | negation = None 201 | is_injection_hybrid = False 202 | 203 | if self.is_cloneofsimo_lora(): 204 | is_injection_hybrid = True 205 | injector_args = lora_loader_args 206 | 207 | params, negation = self.lora_injector(**injector_args) 208 | for _up, _down in extract_lora_ups_down( 209 | model, 210 | target_replace_module=REPLACE_MODULES): 211 | 212 | if all(x is not None for x in [_up, _down]): 213 | print(f"Lora successfully injected into {model.__class__.__name__}.") 214 | 215 | break 216 | 217 | return params, negation, is_injection_hybrid 218 | 219 | if self.is_stable_lora(): 220 | injector_args = lora_args.copy() 221 | injector_args = filter_dict(injector_args, keys=STABLE_LORA_KEYS) 222 | 223 | SEARCH_CLASS = [torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Embedding] 224 | 225 | injector_args.update({ 226 | "model": model, 227 | "target_module": REPLACE_MODULES, 228 | "search_class": SEARCH_CLASS, 229 | "r": r, 230 | "dropout": dropout, 231 | "lora_bias": self.lora_bias 232 | }) 233 | 234 | activator = self.lora_injector(**injector_args) 235 | activator() 236 | 237 | return params, negation, is_injection_hybrid 238 | 239 | def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16): 240 | 241 | params = None 242 | negation = None 243 | 244 | lora_loader_args = self.get_lora_func_args( 245 | lora_path, 246 | use_lora, 247 | model, 248 | replace_modules, 249 | r, 250 | dropout, 251 | self.lora_bias 252 | ) 253 | if use_lora: 254 | params, negation, is_injection_hybrid = self.do_lora_injection( 255 | model, 256 | replace_modules, 257 | bias=self.lora_bias, 258 | lora_loader_args=lora_loader_args, 259 | dropout=dropout, 260 | r=r 261 | ) 262 | 263 | if not is_injection_hybrid: 264 | self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args) 265 | 266 | params = model if params is None else params 267 | return params, negation 268 | 269 | 270 | def deactivate_lora_train(self, models, deactivate=True): 271 | """ 272 | Usage: Use before and after sampling previews. 273 | Currently only available for Stable LoRA. 
274 | """ 275 | if self.is_stable_lora(): 276 | set_mode_group(models, not deactivate) 277 | 278 | def save_cloneofsimo_lora(self, model, save_path, step): 279 | 280 | def save_lora(model, name, condition, replace_modules, step, save_path): 281 | if condition and replace_modules is not None: 282 | save_path = f"{save_path}/{step}_{name}.pt" 283 | save_lora_weight(model, save_path, replace_modules) 284 | 285 | save_lora( 286 | model.unet, 287 | FILE_BASENAMES[0], 288 | self.use_unet_lora, 289 | self.unet_replace_modules, 290 | step, 291 | save_path, 292 | ) 293 | save_lora( 294 | model.text_encoder, 295 | FILE_BASENAMES[1], 296 | self.use_text_lora, 297 | self.text_encoder_replace_modules, 298 | step, 299 | save_path 300 | ) 301 | 302 | train_patch_pipe(model, self.use_unet_lora, self.use_text_lora) 303 | 304 | def save_stable_lora( 305 | self, 306 | model, 307 | step, 308 | name, 309 | save_path = '', 310 | save_for_webui=False, 311 | only_for_webui=False 312 | ): 313 | import uuid 314 | 315 | save_filename = f"{step}_{name}" 316 | lora_metadata = { 317 | "stable_lora_text_to_video": "v1", 318 | "lora_name": name + "_" + uuid.uuid4().hex.lower()[:5] 319 | } 320 | save_lora( 321 | unet=model.unet, 322 | text_encoder=model.text_encoder, 323 | save_text_weights=self.use_text_lora, 324 | output_dir=save_path, 325 | lora_filename=save_filename, 326 | lora_bias=self.lora_bias, 327 | save_for_webui=self.save_for_webui, 328 | only_webui=self.only_for_webui, 329 | metadata=lora_metadata, 330 | unet_dict_converter=convert_unet_state_dict, 331 | text_dict_converter=convert_text_enc_state_dict_v20 332 | ) 333 | 334 | def save_lora_weights(self, model, save_path: str = '', step: str = ''): 335 | save_path = f"{save_path}/lora" 336 | os.makedirs(save_path, exist_ok=True) 337 | 338 | if self.is_cloneofsimo_lora(): 339 | if any([self.save_for_webui, self.only_for_webui]): 340 | warnings.warn( 341 | """ 342 | You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implementation. 343 | Only 'stable_lora' is supported for saving to a compatible webui file. 344 | """ 345 | ) 346 | self.save_cloneofsimo_lora(model, save_path, step) 347 | 348 | if self.is_stable_lora(): 349 | name = 'lora_text_to_video' 350 | self.save_stable_lora(model, step, name, save_path) -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
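# --- Illustrative usage sketch for the LoraHandler defined in utils/lora_handler.py
# --- above; not part of the original ptp_utils.py. `unet` and `pipe` are assumed to
# --- be an already-loaded UNet3DConditionModel and its text-to-video pipeline.
#   handler = LoraHandler(version="cloneofsimo", use_unet_lora=True)
#   unet_lora_params, _ = handler.add_lora_to_model(
#       True, unet, handler.unet_replace_modules, dropout=0.1, lora_path="", r=16
#   )
#   # ... run the training loop ...
#   handler.save_lora_weights(model=pipe, save_path="./output", step=1000)
# --- End of sketch. ---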
14 | 15 | import numpy as np 16 | import torch 17 | 18 | from PIL import Image, ImageDraw, ImageFont 19 | import cv2 20 | import abc 21 | from typing import Optional, Union, Tuple, List, Callable, Dict 22 | #from IPython.display import display 23 | from tqdm.notebook import tqdm 24 | from diffusers.models.cross_attention import CrossAttention 25 | from diffusers.utils import PIL_INTERPOLATION 26 | 27 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 28 | h, w, c = image.shape 29 | offset = int(h * .2) 30 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 31 | font = cv2.FONT_HERSHEY_SIMPLEX 32 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 33 | img[:h] = image 34 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 35 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 36 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 37 | return img 38 | 39 | 40 | def prepare_image(image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance=True 41 | ): 42 | if not isinstance(image, torch.Tensor): 43 | if isinstance(image, Image.Image): 44 | image = [image] 45 | 46 | if isinstance(image[0], Image.Image): 47 | images = [] 48 | 49 | for image_ in image: 50 | image_ = image_.convert("RGB") 51 | image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) 52 | image_ = np.array(image_) 53 | image_ = image_[None, :] 54 | images.append(image_) 55 | 56 | image = images 57 | 58 | image = np.concatenate(image, axis=0) 59 | image = np.array(image).astype(np.float32) / 255.0 60 | image = image.transpose(0, 3, 1, 2) 61 | image = torch.from_numpy(image) 62 | elif isinstance(image[0], torch.Tensor): 63 | image = torch.cat(image, dim=0) 64 | 65 | image_batch_size = image.shape[0] 66 | 67 | if image_batch_size == 1: 68 | repeat_by = batch_size 69 | else: 70 | # image batch size is the same as prompt batch size 71 | repeat_by = num_videos_per_prompt 72 | 73 | image = image.repeat_interleave(repeat_by, dim=0) 74 | 75 | image = image.to(device=device, dtype=dtype) 76 | 77 | if do_classifier_free_guidance: 78 | image = torch.cat([image] * 2) 79 | 80 | return image 81 | 82 | def view_images(images, save_name, num_rows=1, offset_ratio=0.02): 83 | if type(images) is list: 84 | num_empty = len(images) % num_rows 85 | elif images.ndim == 4: 86 | num_empty = images.shape[0] % num_rows 87 | else: 88 | images = [images] 89 | num_empty = 0 90 | 91 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 92 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 93 | num_items = len(images) 94 | #print(images[0].shape) 95 | h, w, c = images[0].shape 96 | offset = int(h * offset_ratio) 97 | num_cols = num_items // num_rows 98 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 99 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 100 | for i in range(num_rows): 101 | for j in range(num_cols): 102 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 103 | i * num_cols + j] 104 | 105 | pil_img = Image.fromarray(image_) 106 | pil_img.save('output/{}.png'.format(save_name)) 107 | #display(pil_img) 108 | 109 | 110 | def diffusion_step(model,latents, context, t, guidance_scale, control_img,low_resource=False, 111 | control=False): 112 | controlnet_conditioning_scale = 1.0 113 | image = prepare_image( 114 | control_img, 115 | 512, 116 | 
512, 117 | 1, 118 | 1, 119 | model.device, 120 | model.unet.dtype, 121 | ) 122 | if low_resource: 123 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 124 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 125 | else: 126 | latents_input = torch.cat([latents] * 2) 127 | #latents_input = model.scheduler.scale_model_input(latents_input, t) 128 | #print(latent_model_input.shape, context.shape, image.shape) 129 | if control: 130 | down_block_res_samples, mid_block_res_sample = model.controlnet( 131 | latents_input, 132 | t, 133 | encoder_hidden_states=context, 134 | controlnet_cond=image, 135 | return_dict=False, 136 | ) 137 | down_block_res_samples = [ 138 | down_block_res_sample * controlnet_conditioning_scale 139 | for down_block_res_sample in down_block_res_samples 140 | ] 141 | mid_block_res_sample *= controlnet_conditioning_scale 142 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, 143 | down_block_additional_residuals=down_block_res_samples, 144 | mid_block_additional_residual=mid_block_res_sample,)["sample"] 145 | else: 146 | noise_pred = model.unet(latents_input,t, encoder_hidden_states=context)["sample"] 147 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 148 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 149 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 150 | #latents = controller.step_callback(latents) 151 | return latents 152 | 153 | 154 | def latent2image(vae, latents): 155 | latents = 1 / 0.18215 * latents 156 | image = vae.decode(latents)['sample'] 157 | image = (image / 2 + 0.5).clamp(0, 1) 158 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 159 | image = (image * 255).astype(np.uint8) 160 | return image 161 | 162 | 163 | def init_latent(latent, model, height, width, generator, batch_size): 164 | if latent is None: 165 | latent = torch.randn( 166 | (1, model.unet.in_channels, height // 8, width // 8), 167 | generator=generator, 168 | ) 169 | #print(latent.shape,batch_size,model.unet.in_channels) 170 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 171 | return latent, latents 172 | 173 | 174 | 175 | 176 | class AttentionControl(abc.ABC): 177 | 178 | def step_callback(self, x_t): 179 | return x_t 180 | 181 | def between_steps(self): 182 | return 183 | 184 | @property 185 | def num_uncond_att_layers(self): 186 | return 0 187 | 188 | @abc.abstractmethod 189 | def forward (self, attn, is_cross: bool, place_in_unet: str): 190 | raise NotImplementedError 191 | 192 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 193 | #self.reset() 194 | if self.cur_att_layer >= self.num_uncond_att_layers: 195 | h = attn.shape[0] 196 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 197 | self.cur_att_layer += 1 198 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 199 | self.cur_att_layer = 0 200 | self.cur_step += 1 201 | self.between_steps() 202 | return attn 203 | 204 | def reset(self): 205 | self.cur_step = 0 206 | self.cur_att_layer = 0 207 | 208 | def __init__(self): 209 | self.cur_step = 0 210 | self.num_att_layers = -1 211 | self.cur_att_layer = 0 212 | 213 | 214 | class AttentionStore(AttentionControl): 215 | 216 | @staticmethod 217 | def get_empty_store(): 218 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 219 | "down_self": [], "mid_self": [], 
"up_self": []} 220 | 221 | def forward(self, attn, is_cross: bool, place_in_unet: str): 222 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 223 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 224 | self.step_store[key].append(attn) 225 | return attn 226 | 227 | def between_steps(self): 228 | if len(self.attention_store) == 0: 229 | self.attention_store = self.step_store 230 | else: 231 | for key in self.attention_store: 232 | for i in range(len(self.attention_store[key])): 233 | self.attention_store[key][i] += self.step_store[key][i] 234 | self.step_store = self.get_empty_store() 235 | 236 | def get_average_attention(self): 237 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 238 | return average_attention 239 | 240 | 241 | def reset(self): 242 | super(AttentionStore, self).reset() 243 | self.step_store = self.get_empty_store() 244 | self.attention_store = {} 245 | 246 | def __init__(self): 247 | super(AttentionStore, self).__init__() 248 | self.step_store = self.get_empty_store() 249 | self.attention_store = {} 250 | 251 | 252 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 253 | if type(image_path) is str: 254 | image = np.array(Image.open(image_path))[:, :, :3] 255 | else: 256 | image = image_path 257 | h, w, c = image.shape 258 | left = min(left, w-1) 259 | right = min(right, w - left - 1) 260 | top = min(top, h - left - 1) 261 | bottom = min(bottom, h - top - 1) 262 | image = image[top:h-bottom, left:w-right] 263 | h, w, c = image.shape 264 | if h < w: 265 | offset = (w - h) // 2 266 | image = image[:, offset:offset + h] 267 | elif w < h: 268 | offset = (h - w) // 2 269 | image = image[offset:offset + w] 270 | image = np.array(Image.fromarray(image).resize((512, 512))) 271 | return image 272 | 273 | class P2PCrossAttnProcessor: 274 | 275 | def __init__(self, controller, place_in_unet): 276 | super().__init__() 277 | self.controller = controller 278 | self.place_in_unet = place_in_unet 279 | 280 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 281 | batch_size, sequence_length, _ = hidden_states.shape 282 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length,batch_size=1) 283 | 284 | query = attn.to_q(hidden_states) 285 | 286 | is_cross = encoder_hidden_states is not None 287 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 288 | key = attn.to_k(encoder_hidden_states) 289 | value = attn.to_v(encoder_hidden_states) 290 | 291 | query = attn.head_to_batch_dim(query) 292 | key = attn.head_to_batch_dim(key) 293 | value = attn.head_to_batch_dim(value) 294 | 295 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 296 | 297 | # one line change 298 | self.controller(attention_probs, is_cross, self.place_in_unet) 299 | 300 | hidden_states = torch.bmm(attention_probs, value) 301 | hidden_states = attn.batch_to_head_dim(hidden_states) 302 | 303 | # linear proj 304 | hidden_states = attn.to_out[0](hidden_states) 305 | # dropout 306 | hidden_states = attn.to_out[1](hidden_states) 307 | 308 | return hidden_states 309 | 310 | def register_attention_control(model, controller,controller1): 311 | attn_procs = {} 312 | cross_att_count = 0 313 | for name in model.unet.attn_processors.keys(): 314 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 315 | if name.startswith("mid_block"): 316 
| hidden_size = model.unet.config.block_out_channels[-1] 317 | place_in_unet = "mid" 318 | elif name.startswith("up_blocks"): 319 | block_id = int(name[len("up_blocks.")]) 320 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 321 | place_in_unet = "up" 322 | elif name.startswith("down_blocks"): 323 | block_id = int(name[len("down_blocks.")]) 324 | hidden_size = model.unet.config.block_out_channels[block_id] 325 | place_in_unet = "down" 326 | else: 327 | continue 328 | cross_att_count += 1 329 | attn_procs[name] = P2PCrossAttnProcessor( 330 | controller=controller, place_in_unet=place_in_unet 331 | ) 332 | model.unet.set_attn_processor(attn_procs) 333 | controller.num_att_layers = cross_att_count 334 | 335 | 336 | attn_procs = {} 337 | cross_att_count = 0 338 | for name in model.controlnet.attn_processors.keys(): 339 | cross_attention_dim = None if name.endswith("attn1.processor") else model.controlnet.config.cross_attention_dim 340 | if name.startswith("mid_block"): 341 | hidden_size = model.controlnet.config.block_out_channels[-1] 342 | place_in_unet = "mid" 343 | elif name.startswith("up_blocks"): 344 | block_id = int(name[len("up_blocks.")]) 345 | hidden_size = list(reversed(model.controlnet.config.block_out_channels))[block_id] 346 | place_in_unet = "up" 347 | elif name.startswith("down_blocks"): 348 | block_id = int(name[len("down_blocks.")]) 349 | hidden_size = model.controlnet.config.block_out_channels[block_id] 350 | place_in_unet = "down" 351 | else: 352 | continue 353 | cross_att_count += 1 354 | attn_procs[name] = P2PCrossAttnProcessor( 355 | controller=controller1, place_in_unet=place_in_unet 356 | ) 357 | 358 | #model.unet.set_attn_processor(attn_procs) 359 | model.controlnet.set_attn_processor(attn_procs) 360 | controller1.num_att_layers = cross_att_count 361 | 362 | 363 | def get_word_inds(text: str, word_place: int, tokenizer): 364 | split_text = text.split(" ") 365 | if type(word_place) is str: 366 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 367 | elif type(word_place) is int: 368 | word_place = [word_place] 369 | out = [] 370 | if len(word_place) > 0: 371 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 372 | cur_len, ptr = 0, 0 373 | 374 | for i in range(len(words_encode)): 375 | cur_len += len(words_encode[i]) 376 | if ptr in word_place: 377 | out.append(i + 1) 378 | if cur_len >= len(split_text[ptr]): 379 | ptr += 1 380 | cur_len = 0 381 | return np.array(out) 382 | 383 | 384 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 385 | word_inds: Optional[torch.Tensor]=None): 386 | if type(bounds) is float: 387 | bounds = 0, bounds 388 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 389 | if word_inds is None: 390 | word_inds = torch.arange(alpha.shape[2]) 391 | alpha[: start, prompt_ind, word_inds] = 0 392 | alpha[start: end, prompt_ind, word_inds] = 1 393 | alpha[end:, prompt_ind, word_inds] = 0 394 | return alpha 395 | 396 | 397 | def get_time_words_attention_alpha(prompts, num_steps, 398 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 399 | tokenizer, max_num_words=77): 400 | if type(cross_replace_steps) is not dict: 401 | cross_replace_steps = {"default_": cross_replace_steps} 402 | if "default_" not in cross_replace_steps: 403 | cross_replace_steps["default_"] = (0., 1.) 
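# --- Illustrative sketch added for clarity; not part of the original ptp_utils.py. ---
# get_word_inds() above maps one word of a prompt to the indices of its sub-word
# tokens in the tokenizer output (offset by 1 for the BOS token), which is what the
# per-word entries of cross_replace_steps rely on here. For example:
#   get_word_inds("a cat riding a skateboard", "cat", tokenizer)
#   # -> array([2]) with the usual CLIP tokenizer, since "cat" is a single token
# --- End of sketch. ---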
404 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 405 | for i in range(len(prompts) - 1): 406 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 407 | i) 408 | for key, item in cross_replace_steps.items(): 409 | if key != "default_": 410 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 411 | for i, ind in enumerate(inds): 412 | if len(ind) > 0: 413 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 414 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 415 | return alpha_time_words 416 | -------------------------------------------------------------------------------- /utils/seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j*gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i*gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i-1]) 88 | y_seq.append(y[j-1]) 89 | i = i-1 90 | j = j-1 91 | mapper_y_to_x.append((j, i)) 
92 | elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j-1]) 95 | j = j-1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i-1]) 99 | y_seq.append('-') 100 | i = i-1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 | j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 | j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | 189 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 190 | x_seq = 
prompts[0] 191 | mappers = [] 192 | for i in range(1, len(prompts)): 193 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 194 | mappers.append(mapper) 195 | return torch.stack(mappers) 196 | 197 | --------------------------------------------------------------------------------
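The sequence-alignment helpers in utils/seq_aligner.py above are normally driven with the pipeline's CLIP tokenizer. A minimal sketch, assuming the openai/clip-vit-large-patch14 tokenizer and two prompts with the same number of words (a requirement of the replacement mapper):

from transformers import CLIPTokenizer
from utils.seq_aligner import get_refinement_mapper, get_replacement_mapper

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompts = ["a cat sitting on a bench", "a dog sitting on a bench"]

# Word-for-word replacement: one (77, 77) soft mapping per edited prompt.
mapper = get_replacement_mapper(prompts, tokenizer)          # shape (1, 77, 77)

# Refinement (words added or removed): a per-token index map plus validity alphas.
mappers, alphas = get_refinement_mapper(prompts, tokenizer)  # (1, 77) and (1, 77)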