├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── app_svd.py
├── colab.ipynb
├── compress_video.py
├── docs
│   ├── 4_sr.mp4
│   ├── barbie2.mp4
│   ├── fish.gif
│   ├── fish.jpg
│   ├── fish_mask.png
│   ├── framework.png
│   ├── girl5.mp4
│   ├── labelme.png
│   ├── pig0.mp4
│   ├── qingming2.gif
│   ├── qingming2_label.jpg
│   ├── sample_1.gif
│   ├── sample_1.png
│   ├── sample_2.gif
│   ├── sample_2.png
│   ├── sample_3.gif
│   └── sample_3.png
├── example
│   ├── barbie.jpg
│   ├── barbie2.jpg
│   ├── deepspeed.yaml
│   ├── example_padded_rgba_pngs
│   │   ├── apple.png
│   │   ├── put rgba images here for train_transparent_i2v_stage2.py.txt
│   │   └── ziyan0.png
│   ├── example_rgba_video_results
│   │   ├── animated rgba results for our transparent unet.txt
│   │   ├── apple
│   │   │   ├── decoded_alpha.webp
│   │   │   └── decoded_rgba.webp
│   │   └── ziyan0
│   │       ├── decoded_alpha.webp
│   │       └── decoded_rgba.webp
│   ├── fish1.jpg
│   ├── fish1_label.jpg
│   ├── girl5.jpg
│   ├── hulu2.jpg
│   ├── hulu3.jpg
│   ├── layerdiffuse_stage2_384.yaml
│   ├── pig0.jpg
│   ├── pig0_label.jpg
│   ├── qingming2.jpg
│   ├── qingming2_label.jpg
│   ├── train_mask_motion.yaml
│   ├── train_mask_motion_lora.yaml
│   ├── train_svd.yaml
│   ├── train_svd_mask.yaml
│   ├── train_svd_v2v.yaml
│   └── validation_file.json
├── models
│   ├── layerdiffuse_VAE.py
│   ├── pipeline.py
│   ├── pipeline_stage2.py
│   ├── unet_3d_blocks.py
│   └── unet_3d_condition_mask.py
├── requirements.txt
├── run.sh
├── stable_lora
│   └── lora.py
├── svd_video2video_examples
│   ├── barbie_input.mp4
│   ├── barbie_mask.png
│   ├── barbie_output.mp4
│   ├── car_input.mp4
│   ├── car_mask_1.png
│   ├── car_mask_2.png
│   ├── car_output_1.mp4
│   ├── car_output_2.mp4
│   ├── windmill_input.mp4
│   ├── windmill_mask.png
│   └── windmill_output.mp4
├── train.py
├── train_lora.py
├── train_svd.py
├── train_transparent_i2v_stage2.py
└── utils
    ├── __init__.py
    ├── bucketing.py
    ├── common.py
    ├── convert_diffusers_to_original_ms_text_to_video.py
    ├── dataset.py
    ├── lama.py
    ├── lora.py
    ├── lora_handler.py
    ├── ptp_utils.py
    └── seq_aligner.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | models/lama.ckpt
3 | .vscode/
4 | models/model_scope_diffusers/
5 | text-to-video-ms-1.7b/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 | configs
137 | output
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alibaba
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 👉 AnimateAnything: Fine Grained Open Domain Image Animation with Motion Guidance
3 |
4 | [Zuozhuo Dai](), [Zhenghao Zhang](), [Menghao Li](), [Junchao Liao](), [Siyu Zhu](), [Long Qin](), [Weizhi Wang]()
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Friendship Link 🔥
12 | - We are excited to announce the open-source release of our latest work: [Tora: Trajectory-oriented Diffusion Transformer for Video Generation](https://github.com/alibaba/Tora). It is the first trajectory-oriented DiT framework that concurrently integrates textual, visual, and trajectory conditions for video generation.
13 |
14 | ## Showcases
15 |
16 | https://github.com/alibaba/animate-anything/assets/1107525/e2659674-c813-402a-8a85-e620f0a6a454
17 |
18 |
19 |
20 | | Input Image with Mask | Prompt | Result |
21 | | --- | --- | --- |
22 | | (image) | Barbie watching the camera with a smiling face. | (result) |
23 | | (image) | The cloak swaying in the wind. | (result) |
24 | | (image) | A red fish is swimming. | (result) |
47 |
48 | ## Framework
49 | ![framework](docs/framework.png)
50 |
51 | ## News 🔥
52 | **2024.2.5**: Support multi-GPU training with Accelerate and DeepSpeed. With DeepSpeed zero_stage 2 and offload_optimizer_device: cpu, you can now fully finetune animate-anything on 4x 16G V100 GPUs and SVD on 4x 24G A10 GPUs.
53 |
54 | **2023.12.27**: Support finetuning based on the SVD (Stable Video Diffusion) model. Released the SVD-based animate_anything_svd_v1.0.
55 |
56 | **2023.12.18**: Update model to animate_anything_512_v1.02
57 |
58 | ## Features Planned
59 | - 💥 Transparent video generation. (Take an RGBA image as input and output animated RGBA videos.)
60 | - ✅ Reproduce the transparent VAE encoder and decoder following [LayerDiffuse](https://github.com/layerdiffusion/sd-forge-layerdiffuse).
61 | - ✅ Finetune the 3D U-Net to support the basic RGBA-image-to-RGBA-video capability.
62 | - 💥 Enhanced prompt-following: generating long, detailed captions using LLaVA.
63 | - 💥 Replace the U-Net with a Diffusion Transformer (DiT) as the base model.
64 | - 💥 Variable resolutions and aspect ratios.
65 | - 💥 Support Huggingface Demo / Google Colab.
66 | - ✅ Support an SVD video2video Google Colab demo. See colab.ipynb.
67 | - ✅ Support LoRA finetuning.
68 | - etc.
69 |
70 | ## Getting Started
71 | This repository is based on [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning.git).
72 |
73 | ### Create Conda Environment (Optional)
74 | It is recommended to install Anaconda.
75 |
76 | **Windows Installation:** https://docs.anaconda.com/anaconda/install/windows/
77 |
78 | **Linux Installation:** https://docs.anaconda.com/anaconda/install/linux/
79 |
80 | ```bash
81 | conda create -n animation python=3.10
82 | conda activate animation
83 | ```
84 |
85 | ### Python Requirements
86 | ```bash
87 | pip install -r requirements.txt
88 | ```
89 |
90 | ## Running inference
91 | Please download the [pretrained model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_512_v1.02.tar) to output/latent, then run the following command. Replace {download_model} with the name of the model you downloaded:
92 | ```bash
93 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/barbie2.jpg validation_data.prompt='A cartoon girl is talking.'
94 | ```
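   | If you prefer to script the download step above, here is a minimal Python sketch (equivalent to the aria2c + tar steps in colab.ipynb; the URL and the output/latent layout are taken from this README):
   |
   | ```python
   | import os
   | import tarfile
   | import urllib.request
   |
   | url = "https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_512_v1.02.tar"
   | os.makedirs("output/latent", exist_ok=True)
   | tar_path = "output/latent/animate_anything_512_v1.02.tar"
   | urllib.request.urlretrieve(url, tar_path)   # download the pretrained model archive
   | with tarfile.open(tar_path) as tf:
   |     tf.extractall("output/latent")          # should yield output/latent/animate_anything_512_v1.02/
   | ```
   |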
95 |
96 | To control the motion area, we can use labelme to generate a binary mask. First, use labelme to draw a polygon on the reference image.
97 |
98 | ![labelme annotation](docs/labelme.png)
99 |
100 | Then run the following command to convert the labelme JSON file into a mask.
101 |
102 | ```bash
103 | labelme_json_to_dataset qingming2.json
104 | ```
105 | 
106 |
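   | labelme_json_to_dataset typically writes a folder (e.g. qingming2_json/) containing a label.png whose background is 0. The minimal sketch below (the folder name and output path are assumptions for illustration) binarizes that label image into the mask file used as validation_data.mask in the inference command below:
   |
   | ```python
   | import numpy as np
   | from PIL import Image
   |
   | label = np.array(Image.open("qingming2_json/label.png").convert("L"))  # assumed labelme_json_to_dataset output
   | mask = np.zeros_like(label, dtype=np.uint8)
   | mask[label != 0] = 255   # every annotated polygon becomes part of the motion area
   | Image.fromarray(mask).save("example/qingming2_label.jpg")
   | ```
   |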
107 | Then run the following command for inference:
108 | ```bash
109 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='Peoples are walking on the street.' validation_data.mask=example/qingming2_label.jpg
110 | ```
111 | 
112 |
113 |
114 | Users can adjust the motion strength using the mask motion model:
115 | ```bash
116 | python train.py --config output/latent/{download_model}/config.yaml --eval validation_data.prompt_image=example/qingming2.jpg validation_data.prompt='Peoples are walking on the street.' validation_data.mask=example/qingming2_label.jpg validation_data.strength=5
118 | ```
119 | ## Video super resolution
120 | The model outputs low-resolution videos; you can use a video super-resolution model to produce high-resolution videos. For example, we can use [Real-CUGAN](https://github.com/bilibili/ailab/tree/main/Real-CUGAN) for cartoon-style video super resolution:
121 |
122 | ```bash
123 | git clone https://github.com/bilibili/ailab.git
124 | cd ailab/Real-CUGAN
125 | python inference_video.py
126 | ```
127 |
128 | ## Training
129 |
130 | ### Using Captions
131 |
132 | You can use caption files when training on videos. Simply place the videos into a folder and create a JSON file with captions like this:
133 |
134 | ```
135 | [
136 | {"caption": "Cute monster character flat design animation video", "video": "000001_000050/1066697179.mp4"},
137 | {"caption": "Landscape of the cherry blossom", "video": "000001_000050/1066688836.mp4"}
138 | ]
139 |
140 | ```
141 | Then, in your config, set dataset_types to video_json and set the video_dir and video_json paths like this:
142 | ```
143 | - dataset_types:
144 | - video_json
145 | train_data:
146 | video_dir: '/webvid/webvid/data/videos'
147 | video_json: '/webvid/webvid/data/40K.json'
148 | ```
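   | If your captions are not in this JSON form yet, a minimal sketch for building it (assuming each video has a same-named .txt caption file next to it, which is an assumption; adapt it to however your captions are stored; the paths mirror the config example above) could look like:
   |
   | ```python
   | import json
   | from pathlib import Path
   |
   | video_dir = Path("/webvid/webvid/data/videos")        # the folder you point train_data.video_dir at
   | entries = []
   | for video in sorted(video_dir.rglob("*.mp4")):
   |     txt = video.with_suffix(".txt")                   # assumed caption location
   |     if txt.exists():
   |         entries.append({
   |             "caption": txt.read_text().strip(),
   |             "video": str(video.relative_to(video_dir)),  # relative path, as in the example above
   |         })
   | with open("/webvid/webvid/data/40K.json", "w") as f:     # the file you point train_data.video_json at
   |     json.dump(entries, f)
   | ```
   |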
149 | ### Process Automatically
150 |
151 | You can automatically caption the videos using the [Video-BLIP2-Preprocessor Script](https://github.com/ExponentialML/Video-BLIP2-Preprocessor) and set the dataset_types and json_path like this:
152 | ```
153 | - dataset_types:
154 | - video_blip
155 | train_data:
156 | json_path: 'blip_generated.json'
157 | ```
158 |
159 | ### Configuration
160 |
161 | The configuration uses a YAML config borrowed from the [Tune-A-Video](https://github.com/showlab/Tune-A-Video) repository.
162 |
163 | All configuration details are placed in `example/train_mask_motion.yaml`. Each parameter has a definition for what it does.
164 |
165 |
166 | ### Finetuning animate-anything
167 | You can finetune animate-anything with text, motion mask, and motion strength guidance on your own dataset. The following config requires around 30G of GPU RAM. You can reduce train_batch_size, train_data.width, train_data.height, and n_sample_frames in the config to lower the GPU RAM requirement:
168 | ```
169 | python train.py --config example/train_mask_motion.yaml pretrained_model_path=
170 | ```
171 |
172 | We also support LoRA finetuning:
173 | ```
174 | python train_lora.py --config example/train_mask_motion_lora.yaml pretrained_model_path=
175 | ```
176 |
177 | ### Finetune Stable Video Diffusion:
178 | The Stable Video Diffusion (SVD) img2vid model can generate high-resolution videos. However, it does not support text or motion mask control. You can finetune SVD with motion mask guidance using the following command and the [pretrained SVD model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_svd_v1.0.tar). This config requires around 80G of GPU RAM.
179 | ```
180 | python train_svd.py --config example/train_svd_mask.yaml pretrained_model_path=
181 | ```
182 |
183 | If you only want to finetune SVD on your own dataset without motion mask control, please use the following config:
184 | ```
185 | python train_svd.py --config example/train_svd.yaml pretrained_model_path=
186 | ```
187 |
188 | ### Multiple GPUs training
189 | We strongly recommend multi-GPU training with Accelerate, which greatly reduces the per-GPU VRAM requirement. Please first configure Accelerate with DeepSpeed; an example config is located in example/deepspeed.yaml.
190 |
191 | Then replace the `python train_xx.py ...` commands above with `accelerate launch train_xx.py ...`, for example:
192 | ```
193 | accelerate launch --config_file example/deepspeed.yaml train_svd.py --config example/train_svd_mask.yaml pretrained_model_path=
194 | ```
195 |
196 | ### SVD video2video
197 | We have released the finetuned vid2vid SVD model; you can try it via the Gradio UI.
198 |
199 | Please download the [vid2vid_SVD model](https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_svd_v1.01.tar) and extract it to output/svd/{download_model} and then run the command:
200 | ```
201 | python app_svd.py --config example/train_svd_v2v.yaml pretrained_model_path=output/svd/{download_model}
202 | ```
203 |
204 | We provide several examples in the svd_video2video_examples directory.
205 |
206 | ## Bibtex
207 | Please cite this paper if you find the code useful for your research:
208 | ```
209 | @misc{dai2023animateanything,
210 | title={AnimateAnything: Fine-Grained Open Domain Image Animation with Motion Guidance},
211 | author={Zuozhuo Dai and Zhenghao Zhang and Yao Yao and Bingxue Qiu and Siyu Zhu and Long Qin and Weizhi Wang},
212 | year={2023},
213 | eprint={2311.12886},
214 | archivePrefix={arXiv},
215 | primaryClass={cs.CV}
216 | }
217 | ```
218 | ## Shoutouts
219 |
220 | - [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning.git)
221 | - [Showlab](https://github.com/showlab/Tune-A-Video) and [bryandlee](https://github.com/bryandlee/Tune-A-Video) for their Tune-A-Video contribution that made this much easier.
222 | - [lucidrains](https://github.com/lucidrains) for their implementations around video diffusion.
223 | - [cloneofsimo](https://github.com/cloneofsimo) for their diffusers implementation of LoRA.
224 | - [kabachuha](https://github.com/kabachuha) for their conversion scripts, training ideas, and webui works.
225 | - [JCBrouwer](https://github.com/JCBrouwer) for their inference implementations.
226 | - [sergiobr](https://github.com/sergiobr) for helpful ideas and bug fixes.
227 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import random
5 | from argparse import ArgumentParser
6 | from datetime import datetime
7 | import math
8 |
9 | import gradio as gr
10 | import numpy as np
11 | import torch
12 | from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
13 | from diffusers.image_processor import VaeImageProcessor
14 | from omegaconf import OmegaConf
15 | from PIL import Image
16 | import torchvision.transforms as T
17 | from einops import rearrange, repeat
18 | import imageio
19 |
20 | from models.pipeline import LatentToVideoPipeline
21 | from utils.common import tensor_to_vae_latent, DDPM_forward
22 |
23 | css = """
24 | .toolbutton {
25 | margin-bottom: 0em;
26 | max-width: 2.5em;
27 | min-width: 2.5em !important;
28 | height: 2.5em;
29 | }
30 | """
31 |
32 |
33 | class AnimateController:
34 | def __init__(self, pretrained_model_path: str, validation_data,
35 | output_dir, motion_mask = False, motion_strength = False):
36 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
37 | # as these models are only used for inference, keeping weights in full precision is not required.
38 | device=torch.device("cuda")
39 | self.validation_data = validation_data
40 | self.output_dir = output_dir
41 | self.pipeline = LatentToVideoPipeline.from_pretrained(pretrained_model_path,
42 | torch_dtype=torch.float16, variant="fp16").to(device)
43 | self.sample_idx = 0
44 |
45 | def animate(
46 | self,
47 | init_img,
48 | motion_scale,
49 | prompt_textbox,
50 | negative_prompt_textbox,
51 | sample_step_slider,
52 | cfg_scale_slider,
53 | seed_textbox,
54 | style,
55 | progress=gr.Progress(),
56 | ):
57 |
58 | if seed_textbox != "-1" and seed_textbox != "":
59 | torch.manual_seed(int(seed_textbox))
60 | else:
61 | torch.seed()
62 | seed = torch.initial_seed()
63 |
64 | vae = self.pipeline.vae
65 | diffusion_scheduler = self.pipeline.scheduler
66 | validation_data = self.validation_data
67 | vae_processor = VaeImageProcessor()
68 |
69 | device = vae.device
70 | dtype = vae.dtype
71 |
72 | pimg = Image.fromarray(init_img["background"]).convert('RGB')
73 | width, height = pimg.size
74 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width))
75 | block_size=8
76 | height = round(height/scale/block_size)*block_size
77 | width = round(width/scale/block_size)*block_size
78 | input_image = vae_processor.preprocess(pimg, height, width)
79 | input_image = input_image.unsqueeze(0).to(dtype).to(device)
80 | input_image_latents = tensor_to_vae_latent(input_image, vae)
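   | # The strokes drawn in the Gradio image editor arrive in the first layer; its alpha channel
   | # is used as the motion mask below, and an empty mask falls back to animating the whole image.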
81 | np_mask = init_img["layers"][0][:,:,3]
82 | np_mask[np_mask!=0] = 255
83 | if np_mask.sum() == 0:
84 | np_mask[:] = 255
85 | save_sample_path = os.path.join(
86 | self.output_dir, f"{self.sample_idx}.mp4")
87 | out_mask_path = os.path.splitext(save_sample_path)[0] + "_mask.jpg"
88 | Image.fromarray(np_mask).save(out_mask_path)
89 |
90 | b, c, _, h, w = input_image_latents.shape
91 | initial_latents, timesteps = DDPM_forward(input_image_latents,
92 | sample_step_slider, validation_data.num_frames, diffusion_scheduler)
93 | mask = T.ToTensor()(np_mask).to(dtype).to(device)
94 | b, c, f, h, w = initial_latents.shape
95 | mask = T.Resize([h, w], antialias=False)(mask)
96 | mask = rearrange(mask, 'b h w -> b 1 1 h w')
97 | motion_strength = motion_scale * mask.mean().item()
98 | print(f"outfile {save_sample_path}, prompt {prompt_textbox}, motion_strength {motion_strength}")
99 | with torch.no_grad():
100 | video_frames, video_latents = self.pipeline(
101 | prompt=prompt_textbox,
102 | latents=initial_latents,
103 | width=width,
104 | height=height,
105 | num_frames=validation_data.num_frames,
106 | num_inference_steps=sample_step_slider,
107 | guidance_scale=cfg_scale_slider,
108 | condition_latent=input_image_latents,
109 | mask=mask,
110 | motion=[motion_strength],
111 | return_dict=False,
112 | timesteps=timesteps,
113 | )
114 |
115 | imageio.mimwrite(save_sample_path, video_frames, fps=8)
116 | self.sample_idx += 1
117 | return save_sample_path
118 |
119 |
120 | def ui(controller):
121 | with gr.Blocks(css=css) as demo:
122 |
123 | gr.HTML(
124 | "Animate Anything
"
125 | )
126 | with gr.Row():
127 | gr.Markdown(
128 | "Project Page " # noqa
129 | "Paper "
130 | "Code " # noqa
131 | "Instructions: 1. Upload image 2. Draw mask on image using draw button. 3. Write prompt. 4.Click generate button. If it is not response, please click again.
"
132 | )
133 |
134 | with gr.Row(equal_height=False):
135 | with gr.Column():
136 | with gr.Row():
137 | init_img = gr.ImageMask(label='Input Image', brush=gr.Brush(default_size=100))
138 | style_dropdown = gr.Dropdown(label='Style', choices=['384', '512'])
139 | with gr.Row():
140 | prompt_textbox = gr.Textbox(label="Prompt", value='moving', lines=1)
141 |
142 | motion_scale_silder = gr.Slider(
143 | label='Motion Strength (Larger value means larger motion but less identity consistency)',
144 | value=5, step=1, minimum=1, maximum=20)
145 |
146 | with gr.Accordion('Advanced Options', open=False):
147 | negative_prompt_textbox = gr.Textbox(
148 | value="", label="Negative prompt", lines=2)
149 |
150 | sample_step_slider = gr.Slider(
151 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
152 |
153 | cfg_scale_slider = gr.Slider(
154 | label="CFG Scale", value=9, minimum=0, maximum=20)
155 |
156 | with gr.Row():
157 | seed_textbox = gr.Textbox(label="Seed", value=-1)
158 | seed_button = gr.Button(
159 | value="\U0001F3B2", elem_classes="toolbutton")
160 | seed_button.click(
161 | fn=lambda: random.randint(1, 10**8),
162 | outputs=[seed_textbox],
163 | queue=False
164 | )
165 |
166 | generate_button = gr.Button(
167 | value="Generate", variant='primary')
168 |
169 | result_video = gr.Video(
170 | label="Generated Animation", interactive=False)
171 |
172 | generate_button.click(
173 | fn=controller.animate,
174 | inputs=[
175 | init_img,
176 | motion_scale_silder,
177 | prompt_textbox,
178 | negative_prompt_textbox,
179 | sample_step_slider,
180 | cfg_scale_slider,
181 | seed_textbox,
182 | style_dropdown,
183 | ],
184 | outputs=[result_video]
185 | )
186 |
187 | def create_example(input_list):
188 | return gr.Examples(
189 | examples=input_list,
190 | inputs=[
191 | init_img,
192 | result_video,
193 | prompt_textbox,
194 | style_dropdown,
195 | motion_scale_silder,
196 | ],
197 | )
198 |
199 | gr.Markdown(
200 | '### Merry Christmas!'
201 | )
202 | create_example(
203 | [
204 | [ 'example/pig0.jpg', 'docs/pig0.mp4', 'pigs are talking', '512', 3],
205 | [ 'example/barbie2.jpg', 'docs/barbie2.mp4', 'a girl is talking', '512', 4],
206 | ],
207 |
208 | )
209 |
210 | return demo
211 |
212 |
213 | if __name__ == "__main__":
214 | parser = ArgumentParser()
215 | parser.add_argument('--config', type=str, default='example/config/base.yaml')
216 | parser.add_argument('--server-name', type=str, default='0.0.0.0')
217 | parser.add_argument('--port', type=int, default=7860)
218 | parser.add_argument('--share', action='store_true', default=False)
219 | parser.add_argument('--local-debug', action='store_true')
220 | parser.add_argument('--save-path', default='samples')
221 |
222 | args, unknownargs = parser.parse_known_args()
223 | LOCAL_DEBUG = args.local_debug
224 | args_dict = OmegaConf.load(args.config)
225 | cli_conf = OmegaConf.from_cli()
226 | args_dict = OmegaConf.merge(args_dict, cli_conf)
227 | controller = AnimateController(args_dict.pretrained_model_path, args_dict.validation_data,
228 | args_dict.output_dir, args_dict.motion_mask, args_dict.motion_strength)
229 | demo = ui(controller)
230 | demo.queue(max_size=10)
231 | demo.launch(server_name=args.server_name,
232 | server_port=args.port, max_threads=40,
233 | allowed_paths=['example/barbie2.jpg'],
234 | share=args.share)
235 |
--------------------------------------------------------------------------------
/app_svd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from argparse import ArgumentParser
4 | import math
5 |
6 | import gradio as gr
7 | import torch
8 | from diffusers.image_processor import VaeImageProcessor
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 | import torchvision.transforms as T
12 | import imageio
13 |
14 | from diffusers import StableVideoDiffusionPipeline
15 | from models.pipeline import TextStableVideoDiffusionPipeline
16 | from einops import rearrange, repeat
17 | from utils.common import read_video
18 |
19 | css = """
20 | .toolbutton {
21 | margin-bottom: 0em;
22 | max-width: 2.5em;
23 | min-width: 2.5em !important;
24 | height: 2.5em;
25 | }
26 | """
27 |
28 |
29 | class AnimateController:
30 | def __init__(self, pretrained_model_path: str, validation_data,
31 | output_dir, motion_mask = False, motion_strength = False):
32 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
33 | # as these models are only used for inference, keeping weights in full precision is not required.
34 | device=torch.device("cuda")
35 | self.validation_data = validation_data
36 | self.output_dir = output_dir
37 | self.pipeline = StableVideoDiffusionPipeline.from_pretrained(pretrained_model_path, torch_dtype=torch.float16, variant="fp16").to(device)
38 | #self.pipeline = StableVideoDiffusionPipeline.from_pretrained(pretrained_model_path).to(device)
39 | self.sample_idx = 0
40 |
41 | def animate(
42 | self,
43 | init_img,
44 | input_video,
45 | sample_step_slider,
46 | seed_textbox,
47 | fps_textbox,
48 | num_frames_textbox,
49 | motion_bucket_id_slider,
50 | progress=gr.Progress(),
51 | ):
52 |
53 | if seed_textbox != "-1" and seed_textbox != "":
54 | torch.manual_seed(int(seed_textbox))
55 | else:
56 | torch.seed()
57 | seed = torch.initial_seed()
58 |
59 | with torch.no_grad():
60 | vae = self.pipeline.vae
61 | validation_data = self.validation_data
62 | validation_data.fps = int(fps_textbox)
63 | validation_data.num_frames = int(num_frames_textbox)
64 | validation_data.motion_bucket_id = int(motion_bucket_id_slider)
65 | vae_processor = VaeImageProcessor()
66 |
67 | device = vae.device
68 | dtype = vae.dtype
69 |
70 | f = validation_data.num_frames
71 | pimg = Image.fromarray(init_img["background"]).convert('RGB')
72 | np_mask = init_img["layers"][0][:,:,3]
73 | np_mask[np_mask!=0] = 255
74 | if np_mask.sum() == 0:
75 | np_mask[:] = 255
76 | if input_video is not None:
77 | frames = read_video(input_video)
78 | frames = [Image.fromarray(f) for f in frames]
79 | pimg = frames[0]
80 | width, height = pimg.size
81 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width))
82 | block_size=64
83 | height = round(height/scale/block_size)*block_size
84 | width = round(width/scale/block_size)*block_size
85 | f = len(frames)
86 |
87 | latents = []
88 | for frame in frames:
89 | input_image = vae_processor.preprocess(frame, height, width)
90 | input_image = input_image.to(dtype).to(device)
91 | input_image_latent = vae.encode(input_image).latent_dist.mode() * vae.config.scaling_factor
92 | latents.append(input_image_latent.unsqueeze(1))
93 | latents = torch.cat(latents, dim=1)
94 | else:
95 | width, height = pimg.size
96 | scale = math.sqrt(width*height / (validation_data.height*validation_data.width))
97 | block_size=64
98 | height = round(height/scale/block_size)*block_size
99 | width = round(width/scale/block_size)*block_size
100 | input_image = vae_processor.preprocess(pimg, height, width)
101 | input_image = input_image.to(dtype).to(device)
102 | input_image_latent = vae.encode(input_image).latent_dist.mode() * vae.config.scaling_factor
103 | latents = repeat(input_image_latent, 'b c h w->b f c h w', f=f)
104 |
105 | b, f, c, h, w = latents.shape
106 |
107 | mask = T.ToTensor()(np_mask).to(dtype).to(device)
108 | mask = T.Resize([h, w], antialias=False)(mask)
109 | mask = repeat(mask, 'b h w -> b f 1 h w', f=f).detach().clone()
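   | # Frame 0 is always kept unmasked so it serves as the conditioning frame; in masked regions of the
   | # remaining frames the condition latent falls back to the frozen first-frame latent, while the
   | # unmasked regions keep the per-frame input latent.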
110 | mask[:,0] = 0
111 | freeze = repeat(latents[:,0], 'b c h w -> b f c h w', f=f)
112 | condition_latents = latents * (1-mask) + freeze * mask
113 | condition_latents = condition_latents/vae.config.scaling_factor
114 |
115 | motion_mask = self.pipeline.unet.config.in_channels == 9
116 | decode_chunk_size=validation_data.get("decode_chunk_size", 7)
117 | fps=validation_data.get("fps", 7)
118 | motion_bucket_id=validation_data.get("motion_bucket_id", 127)
119 | if motion_mask:
120 | video_frames = TextStableVideoDiffusionPipeline.__call__(
121 | self.pipeline,
122 | image=pimg,
123 | width=width,
124 | height=height,
125 | num_frames=validation_data.num_frames,
126 | num_inference_steps=validation_data.num_inference_steps,
127 | decode_chunk_size=decode_chunk_size,
128 | fps=fps,
129 | motion_bucket_id=motion_bucket_id,
130 | mask=mask,
131 | condition_type="image",
132 | condition_latent=condition_latents
133 | ).frames[0]
134 | else:
135 | video_frames = self.pipeline(
136 | image=pimg,
137 | width=width,
138 | height=height,
139 | num_frames=validation_data.num_frames,
140 | num_inference_steps=validation_data.num_inference_steps,
141 | fps=validation_data.fps,
142 | decode_chunk_size=validation_data.decode_chunk_size,
143 | motion_bucket_id=validation_data.motion_bucket_id,
144 | ).frames[0]
145 |
146 | save_sample_path = os.path.join(
147 | self.output_dir, f"{self.sample_idx}.mp4")
148 | Image.fromarray(np_mask).save(os.path.join(
149 | self.output_dir, f"{self.sample_idx}_label.jpg"))
150 | imageio.mimwrite(save_sample_path, video_frames, fps=7)
151 | self.sample_idx += 1
152 | return save_sample_path
153 |
154 | import cv2
155 |
156 | def get_video_info(video_path):
157 | cap = cv2.VideoCapture(video_path)
158 | if not cap.isOpened():
159 | return None
160 |
161 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
162 | cap.release()
163 |
164 | return length
165 |
166 | def update_num_frames(input_video, num_frames_textbox):
167 | frame_count = get_video_info(input_video)
168 | return frame_count or 14
169 |
170 | def ui(controller):
171 | with gr.Blocks(css=css) as demo:
172 |
173 | gr.HTML(
174 | "
Animate Anything For SVD "
175 | )
176 | with gr.Row():
177 | gr.Markdown(
178 | "Project Page " # noqa
179 | "Paper "
180 | "Code " # noqa
181 | )
182 |
183 | with gr.Row(equal_height=True):
184 | with gr.Column():
185 | init_img = gr.ImageMask(label='Input Image', brush=gr.Brush(default_size=100))
186 | generate_button = gr.Button(
187 | value="Generate", variant='primary')
188 | input_video = gr.Video(label="Input video", interactive=True)
189 |
190 | result_video = gr.Video(
191 | label="Generated Animation", interactive=False)
192 |
193 | with gr.Accordion('Advanced Options', open=False):
194 | with gr.Row():
195 | fps_textbox = gr.Number(label="Fps", value=7, minimum=1)
196 | num_frames_textbox = gr.Number(label="Num frames", value=14, minimum=1, maximum=78)
197 |
198 | input_video.upload(
199 | fn=update_num_frames,
200 | inputs=[input_video],
201 | outputs=[num_frames_textbox]
202 | )
203 |
204 | motion_bucket_id_slider = gr.Slider(
205 | label='motion_bucket_id',
206 | value=127, step=1, minimum=0, maximum=511)
207 |
208 | sample_step_slider = gr.Slider(
209 | label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
210 |
211 | with gr.Row():
212 | seed_textbox = gr.Textbox(label="Seed", value=-1)
213 | seed_button = gr.Button(
214 | value="\U0001F3B2", elem_classes="toolbutton")
215 | seed_button.click(
216 | fn=lambda: random.randint(1, 10**8),
217 | outputs=[seed_textbox],
218 | queue=False
219 | )
220 |
221 |
222 |
223 | generate_button.click(
224 | fn=controller.animate,
225 | inputs=[
226 | init_img,
227 | input_video,
228 | sample_step_slider,
229 | seed_textbox,
230 | fps_textbox,
231 | num_frames_textbox,
232 | motion_bucket_id_slider
233 | ],
234 | outputs=[result_video]
235 | )
236 |
237 | return demo
238 |
239 |
240 | if __name__ == "__main__":
241 | parser = ArgumentParser()
242 | parser.add_argument('--config', type=str, default='example/config/base.yaml')
243 | parser.add_argument('--server-name', type=str, default='0.0.0.0')
244 | parser.add_argument('--port', type=int, default=7860)
245 | parser.add_argument('--share', action='store_true')
246 | parser.add_argument('--local-debug', action='store_true')
247 | parser.add_argument('--save-path', default='samples')
248 |
249 | args, unknownargs = parser.parse_known_args()
250 | LOCAL_DEBUG = args.local_debug
251 | args_dict = OmegaConf.load(args.config)
252 | cli_conf = OmegaConf.from_cli()
253 | args_dict = OmegaConf.merge(args_dict, cli_conf)
254 | controller = AnimateController(args_dict.pretrained_model_path, args_dict.validation_data,
255 | args_dict.output_dir, args_dict.motion_mask, args_dict.motion_strength)
256 | demo = ui(controller)
257 | demo.queue(max_size=10)
258 | demo.launch(server_name=args.server_name,
259 | server_port=args.port, max_threads=40,
260 | )
261 |
--------------------------------------------------------------------------------
/colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "[](https://colab.research.google.com/github/dailingx/animate-anything/blob/main/colab.ipynb)"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone https://github.com/alibaba/animate-anything /content/animate-anything\n",
22 | "%cd /content/animate-anything\n",
23 | "\n",
24 | "!pip install -r requirements.txt\n",
25 | "\n",
26 | "!apt -y install -qq aria2\n",
27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://cloudbook-public-production.oss-cn-shanghai.aliyuncs.com/animation/animate_anything_512_v1.02.tar -d output/latent\n",
28 | "!tar -xf output/latent/animate_anything_512_v1.02.tar -C output/latent/\n",
29 | "\n",
30 | "!python app.py --config output/latent/animate_anything_512_v1.02/config.yaml --share"
31 | ],
32 | "metadata": {
33 | "collapsed": false,
34 | "pycharm": {
35 | "name": "#%%\n"
36 | }
37 | }
38 | }
39 | ],
40 | "metadata": {
41 | "accelerator": "GPU",
42 | "colab": {
43 | "gpuType": "T4",
44 | "provenance": []
45 | },
46 | "kernelspec": {
47 | "display_name": "Python 3",
48 | "language": "python",
49 | "name": "python3"
50 | },
51 | "language_info": {
52 | "codemirror_mode": {
53 | "name": "ipython",
54 | "version": 2
55 | },
56 | "file_extension": ".py",
57 | "mimetype": "text/x-python",
58 | "name": "python",
59 | "nbconvert_exporter": "python",
60 | "pygments_lexer": "ipython2",
61 | "version": "2.7.6"
62 | }
63 | },
64 | "nbformat": 4,
65 | "nbformat_minor": 0
66 | }
67 |
--------------------------------------------------------------------------------
/compress_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Used to compress video in: https://github.com/ArrowLuo/CLIP4Clip
3 | Author: ArrowLuo
4 | """
5 | import os
6 | import argparse
7 | import ffmpeg
8 | import subprocess
9 | import time
10 | import multiprocessing
11 | from multiprocessing import Pool
12 | import shutil
13 | import json
14 | try:
15 | from psutil import cpu_count
16 | except:
17 | from multiprocessing import cpu_count
18 | # multiprocessing.freeze_support()
19 |
20 | def compress(paras):
21 | input_video_path, output_video_path = paras
22 | try:
23 | command = ['ffmpeg',
24 | '-y', # (optional) overwrite output file if it exists
25 | '-i', input_video_path,
26 | '-filter:v',
27 | 'scale=\'if(gt(a,1),trunc(oh*a/2)*2,512)\':\'if(gt(a,1),512,trunc(ow*a/2)*2)\'', # scale the shorter side to 512, rounding to even dimensions
28 | '-map', '0:v',
29 | #'-r', '3', # frames per second
30 | output_video_path,
31 | ]
32 | proc = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)  # avoid shadowing the ffmpeg module
33 | out, err = proc.communicate()
34 | retcode = proc.poll()
35 | # print something above for debug
36 | except Exception as e:
37 | raise e
38 |
39 | def prepare_input_output_pairs(input_root, output_root):
40 | input_video_path_list = []
41 | output_video_path_list = []
42 | for root, dirs, files in os.walk(input_root):
43 | for file_name in files:
44 | input_video_path = os.path.join(root, file_name)
45 | output_video_path = os.path.join(output_root, file_name)
46 | output_video_path = os.path.splitext(output_video_path)[0] + ".mp4"
47 | if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0:
48 | pass
49 | else:
50 | input_video_path_list.append(input_video_path)
51 | output_video_path_list.append(output_video_path)
52 | return input_video_path_list, output_video_path_list
53 |
54 | def msvd():
   | import pickle  # missing import: captions are loaded from a pickle file
55 | captions = pickle.load(open('raw-captions.pkl','rb'))
56 | outdir = "/data/datasets/msvd/videos_mp4"
57 | for key in captions:
58 | outpath = os.path.join(outdir, key+".txt")
59 | with open(outpath, 'w') as f:
60 | for line in captions[key]:
61 | f.write(" ".join(line)+"\n")
62 |
63 | def webvid():
   | import numpy as np   # missing import: used for np.sort below
   | import pandas as pd  # missing import: used to read the WebVid csv
64 |
65 | df = pd.read_csv('/webvid/results_2M_train_1/0.csv')
66 | df['rel_fn'] = df.apply(lambda x: os.path.join(str(x['page_dir']), str(x['videoid'])), axis=1)
67 |
68 | df['rel_fn'] = df['rel_fn'] + '.mp4'
69 | # remove nan
70 | df.dropna(subset=['page_dir'], inplace=True)
71 |
72 | playlists_to_dl = np.sort(df['page_dir'].unique())
73 |
74 | vjson = []
75 | video_dir = '/webvid/webvid/data/videos'
76 | for page_dir in playlists_to_dl:
77 | vid_dir_t = os.path.join(video_dir, page_dir)
78 | pdf = df[df['page_dir'] == page_dir]
79 | if len(pdf) > 0:
80 | for idx, row in pdf.iterrows():
81 | video_fp = os.path.join(vid_dir_t, str(row['videoid']) + '.mp4')
82 | if os.path.isfile(video_fp):
83 | caption = row['name']
84 | video_path = os.path.join(page_dir, str(row['videoid'])+'.mp4')
85 | vjson.append({'caption':caption,'video':video_path})
86 | with open('/webvid/webvid/data/2M.json', 'w') as f:
87 | json.dump(vjson, f)
88 |
89 | def webvid20k():
90 | j = json.load(open('/webvid/webvid/data/2M.json'))
91 | idir = '/webvid/webvid/data/videos'
92 |
93 | v2c = []
94 | for item in j:
95 | caption = item['caption']
96 | video = item['video']
97 | if os.path.exists(os.path.join(idir, video)):
98 | v2c.append(item)
99 | print("video numbers", len(v2c))
100 | with open('/webvid/webvid/data/40K.json', 'w') as f:
101 | json.dump(v2c, f)
102 |
103 |
104 | if __name__ == "__main__":
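   | # Usage sketch (placeholder paths): python compress_video.py --input_root <raw_videos_dir> --output_root <compressed_videos_dir>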
105 | parser = argparse.ArgumentParser(description='Compress video for speed-up')
106 | parser.add_argument('--input_root', type=str, help='input root')
107 | parser.add_argument('--output_root', type=str, help='output root')
108 | args = parser.parse_args()
109 |
110 | input_root = args.input_root
111 | output_root = args.output_root
112 |
113 | assert input_root != output_root
114 |
115 | if not os.path.exists(output_root):
116 | os.makedirs(output_root, exist_ok=True)
117 |
118 | input_video_path_list, output_video_path_list = prepare_input_output_pairs(input_root, output_root)
119 |
120 | print("Total video need to process: {}".format(len(input_video_path_list)))
121 | num_works = cpu_count()
122 | print("Begin with {}-core logical processor.".format(num_works))
123 |
124 | pool = Pool(num_works)
125 | data_dict_list = pool.map(compress,
126 | [(input_video_path, output_video_path) for
127 | input_video_path, output_video_path in
128 | zip(input_video_path_list, output_video_path_list)])
129 | pool.close()
130 | pool.join()
131 |
132 | print("Compress finished, wait for checking files...")
133 | for input_video_path, output_video_path in zip(input_video_path_list, output_video_path_list):
134 | if os.path.exists(input_video_path):
135 | if os.path.exists(output_video_path) is False or os.path.getsize(output_video_path) < 1.:
136 | print("convert fail: {}".format(output_video_path))
--------------------------------------------------------------------------------
/docs/4_sr.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/4_sr.mp4
--------------------------------------------------------------------------------
/docs/barbie2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/barbie2.mp4
--------------------------------------------------------------------------------
/docs/fish.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish.gif
--------------------------------------------------------------------------------
/docs/fish.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish.jpg
--------------------------------------------------------------------------------
/docs/fish_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/fish_mask.png
--------------------------------------------------------------------------------
/docs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/framework.png
--------------------------------------------------------------------------------
/docs/girl5.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/girl5.mp4
--------------------------------------------------------------------------------
/docs/labelme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/labelme.png
--------------------------------------------------------------------------------
/docs/pig0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/pig0.mp4
--------------------------------------------------------------------------------
/docs/qingming2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/qingming2.gif
--------------------------------------------------------------------------------
/docs/qingming2_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/qingming2_label.jpg
--------------------------------------------------------------------------------
/docs/sample_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_1.gif
--------------------------------------------------------------------------------
/docs/sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_1.png
--------------------------------------------------------------------------------
/docs/sample_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_2.gif
--------------------------------------------------------------------------------
/docs/sample_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_2.png
--------------------------------------------------------------------------------
/docs/sample_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_3.gif
--------------------------------------------------------------------------------
/docs/sample_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/docs/sample_3.png
--------------------------------------------------------------------------------
/example/barbie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/barbie.jpg
--------------------------------------------------------------------------------
/example/barbie2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/barbie2.jpg
--------------------------------------------------------------------------------
/example/deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 1
5 | offload_optimizer_device: cpu
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | gpu_ids: 0,1,2,3
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: fp16
15 | num_machines: 1
16 | num_processes: 4
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/example/example_padded_rgba_pngs/apple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/apple.png
--------------------------------------------------------------------------------
/example/example_padded_rgba_pngs/put rgba images here for train_transparent_i2v_stage2.py.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/put rgba images here for train_transparent_i2v_stage2.py.txt
--------------------------------------------------------------------------------
/example/example_padded_rgba_pngs/ziyan0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_padded_rgba_pngs/ziyan0.png
--------------------------------------------------------------------------------
/example/example_rgba_video_results/animated rgba results for our transparent unet.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/animated rgba results for our transparent unet.txt
--------------------------------------------------------------------------------
/example/example_rgba_video_results/apple/decoded_alpha.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/apple/decoded_alpha.webp
--------------------------------------------------------------------------------
/example/example_rgba_video_results/apple/decoded_rgba.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/apple/decoded_rgba.webp
--------------------------------------------------------------------------------
/example/example_rgba_video_results/ziyan0/decoded_alpha.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/ziyan0/decoded_alpha.webp
--------------------------------------------------------------------------------
/example/example_rgba_video_results/ziyan0/decoded_rgba.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/example_rgba_video_results/ziyan0/decoded_rgba.webp
--------------------------------------------------------------------------------
/example/fish1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/fish1.jpg
--------------------------------------------------------------------------------
/example/fish1_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/fish1_label.jpg
--------------------------------------------------------------------------------
/example/girl5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/girl5.jpg
--------------------------------------------------------------------------------
/example/hulu2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/hulu2.jpg
--------------------------------------------------------------------------------
/example/hulu3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/hulu3.jpg
--------------------------------------------------------------------------------
/example/layerdiffuse_stage2_384.yaml:
--------------------------------------------------------------------------------
1 | # Pretrained diffusers model path.
2 | transparent_unet_pretrained_model_path: "./output/latent/transparent_unet"
3 | transparent_VAE_pretrained_model_path: "./output/latent/transparent_VAE"
4 |
5 | motion_mask: True
6 | motion_strength: True
7 | in_channels: 5 # 5 or 9
8 |
9 | # The folder where your training outputs will be placed.
10 | output_dir: "output/stage_2_eval"
11 |
12 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
13 | # If this is enabled, rescale_schedule will be disabled.
14 | offset_noise_strength: 0.1
15 | use_offset_noise: False
16 |
17 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
18 | # If this is enabled, offset noise will be disabled.
19 | rescale_schedule: True
20 |
21 | # When True, this extends all items in all enabled datasets to the highest length.
22 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
23 | extend_dataset: False
24 |
25 | # Caches the latents (Frames-Image -> VAE -> Latent) to a HDD or SDD.
26 | # The latents will be saved under your training folder, and loaded automatically for training.
27 | # This both saves memory and speeds up training and takes very little disk space.
28 | cache_latents: False
29 |
30 | # If you have cached latents set to `True` and have a directory of cached latents,
31 | # you can skip the caching process and load previously saved ones.
32 | cached_latent_dir: null #/path/to/cached_latents
33 |
34 | # Train the text encoder for the model. LoRA Training overrides this setting.
35 | train_text_encoder: False
36 |
37 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
38 | # This is the first, original implementation of LoRA by cloneofsimo.
39 | # Use this version if you want to maintain compatibility to the original version.
40 |
41 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
42 | # This is an implementation based off of the original LoRA repository by Microsoft, and the default LoRA method here.
43 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
44 | # This means that there isn't an extra function when doing low ranking adaption.
45 | # It solely saves the weight differential between the initialized weights and updates.
46 |
47 | # "cloneofsimo" or "stable_lora"
48 | lora_version: "cloneofsimo"
49 |
50 | # Use LoRA for the UNET model.
51 | use_unet_lora: False
52 |
53 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
54 | use_text_lora: False
55 |
56 | # LoRA Dropout. This parameter adds the probability of randomly zeros out elements. Helps prevent overfitting.
57 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
58 | lora_unet_dropout: 0.1
59 |
60 | lora_text_dropout: 0.1
61 |
62 | # https://github.com/kabachuha/sd-webui-text2video
63 | # This saves a LoRA that is compatible with the text2video webui extension.
64 | # It only works when the lora version is 'stable_lora'.
65 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT behave the same way.
66 | save_lora_for_webui: True
67 |
68 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
69 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
70 | # when this version is set to False
71 | only_lora_for_webui: False
72 |
73 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
74 | # The only time you want this off is if you're doing full LoRA training.
75 | save_pretrained_model: True
76 |
77 | # The modules to use for LoRA. Different from 'trainable_modules'.
78 | unet_lora_modules:
79 | - "UNet3DConditionModel"
80 | #- "ResnetBlock2D"
81 | #- "TransformerTemporalModel"
82 | #- "Transformer2DModel"
83 | #- "CrossAttention"
84 | #- "Attention"
85 | #- "GEGLU"
86 | #- "TemporalConvLayer"
87 |
88 | # The modules to use for LoRA. Different from `trainable_text_modules`.
89 | text_encoder_lora_modules:
90 | - "CLIPEncoderLayer"
91 | #- "CLIPAttention"
92 |
93 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
94 | # VRAM usage increases with higher rank and decreases with lower rank.
95 | lora_rank: 16
96 |
97 | # You can train multiple datasets at once. They will be joined together for training.
98 | # Simply remove the line you don't need, or keep them all for mixed training.
99 |
100 | # 'image': A folder of images and captions (.txt)
101 | # 'folder': A folder of videos and captions (.txt)
102 | # 'json': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
103 | # 'video_json': a video folder and a JSON caption file
104 | # 'single_video': A single video file.mp4 and text prompt
105 | dataset_types:
106 | #- 'single_video'
107 | #- 'folder'
108 | # - 'image'
109 | - 'video_blip'
110 | # - 'video_json'
111 |
112 | # Training data parameters
113 | train_data:
114 | width: 384
115 | height: 384
116 | use_bucketing: False
117 | return_mask: True
118 | return_motion: True
119 | sample_start_idx: 0
120 | fps: 8
121 | n_sample_frames: 8
122 |
123 | json_path: ''
124 |
125 |
126 | # Validation data parameters.
127 | validation_data:
128 |
129 | # A custom prompt that is different from your training dataset.
130 | prompt: ""
131 |
132 | prompt_image: ""
133 |
134 | # Whether or not to sample preview during training (Requires more VRAM).
135 | sample_preview: True
136 |
137 | # The number of frames to sample during validation.
138 | num_frames: 8
139 |
140 | # Height and width of validation sample.
141 | width: 384
142 | height: 384
143 |
144 | # Number of inference steps when generating the video.
145 | num_inference_steps: 25
146 |
147 | # CFG scale
148 | guidance_scale: 9
149 |
150 | # Learning rate for AdamW
151 | learning_rate: 3.0e-05
152 | lr_scheduler: "cosine"
153 | lr_warmup_steps: 20
154 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
155 | adam_weight_decay: 0
156 |
157 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
158 | extra_unet_params: null
159 | #learning_rate: 1e-5
160 | #adam_weight_decay: 1e-4
161 |
162 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
163 | extra_text_encoder_params: null
164 | #learning_rate: 1e-4
165 | #adam_weight_decay: 0.2
166 |
167 | # The number of samples per training batch. Not to be confused with video frames.
168 | train_batch_size: 8
169 | image_batch_size: 48
170 | gradient_accumulation_steps: 4
171 | # Maximum number of train steps. Model is saved after training.
172 | max_train_steps: 2000
173 |
174 | # Saves a model every nth step.
175 | checkpointing_steps: 200
176 |
177 | # How many steps to do for validation if sample_preview is enabled.
178 | validation_steps: 50
179 |
180 | # Which modules we want to unfreeze for the UNET. Advanced usage.
181 | trainable_modules:
182 | - "all"
183 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions"
184 | # This is for self attention. Activates for spatial and temporal dimensions if n_sample_frames > 1
185 | - "attn1"
186 | - ".attentions"
187 |
188 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1
189 | - "attn2"
190 | - "conv_in"
191 |
192 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1
193 | - "temp_conv"
194 | - "motion"
195 |
196 |
197 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
198 | trainable_text_modules: null
199 |
200 | # Seed for validation.
201 | seed: null
202 |
203 | # Whether or not we want to use mixed precision with accelerate
204 | mixed_precision: "fp16"
205 |
206 | # This seems to be incompatible at the moment.
207 | use_8bit_adam: False
208 |
209 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
210 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
211 | gradient_checkpointing: True
212 | text_encoder_gradient_checkpointing: False
213 |
214 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
215 | enable_xformers_memory_efficient_attention: False
216 |
217 | # Use scaled dot product attention (Only available with >= Torch 2.0)
218 | enable_torch_2_attn: True
219 |
--------------------------------------------------------------------------------
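A note on how configs like the one above are consumed: the training scripts take a `--config` path and read the YAML into a nested config object. The sketch below assumes OmegaConf as the loader (a common choice for this style of YAML config; the exact loader used by train.py is not reproduced in this dump), and the file name and override keys are illustrative only.

```python
# Minimal sketch, assuming the YAML configs are loaded with OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.load("example/train_mask_motion.yaml")

# Nested sections become attribute-style objects.
print(config.train_data.width, config.train_data.height)
print(config.validation_data.num_inference_steps)

# Command-line style overrides can be merged on top of the file values.
overrides = OmegaConf.from_dotlist(["train_batch_size=2", "max_train_steps=100"])
config = OmegaConf.merge(config, overrides)
```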
/example/pig0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/pig0.jpg
--------------------------------------------------------------------------------
/example/pig0_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/pig0_label.jpg
--------------------------------------------------------------------------------
/example/qingming2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/qingming2.jpg
--------------------------------------------------------------------------------
/example/qingming2_label.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/example/qingming2_label.jpg
--------------------------------------------------------------------------------
/example/train_mask_motion.yaml:
--------------------------------------------------------------------------------
1 | # Pretrained diffusers model path.
2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main
3 | pretrained_model_path: "output/latent/animate_anything_512_v1.02"
4 | #pretrained_model_path: "output/latent/train_4fps"
5 | #pretrained_model_path: "/data/llm/zeroscope_v2_576w"
6 |
7 | motion_mask: True
8 | motion_strength: True
9 |
10 | # The folder where your training outputs will be placed.
11 | output_dir: "./output/latent"
12 |
13 | # You can train multiple datasets at once. They will be joined together for training.
14 | # Simply remove the line you don't need, or keep them all for mixed training.
15 |
16 | # 'image': A folder of images and captions (.txt)
17 | # 'folder': A folder of videos and captions (.txt)
18 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
19 | # 'video_json': A video folder and a JSON caption file
20 | # 'single_video': A single video file (.mp4) and a text prompt
21 | dataset_types:
22 | #- 'single_video'
23 | #- 'folder'
24 | #- 'image'
25 | - 'video_blip'
26 | #- 'video_json'
27 |
28 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
29 | # If this is enabled, rescale_schedule will be disabled.
30 | offset_noise_strength: 0.1
31 | use_offset_noise: False
32 |
33 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
34 | # If this is enabled, offset noise will be disabled.
35 | rescale_schedule: True
36 |
37 | # When True, this extends all items in all enabled datasets to the highest length.
38 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
39 | extend_dataset: False
40 |
41 | # Caches the latents (Frames-Image -> VAE -> Latent) to an HDD or SSD.
42 | # The latents will be saved under your training folder, and loaded automatically for training.
43 | # This saves memory, speeds up training, and takes very little disk space.
44 | cache_latents: False
45 |
46 | # If you have cached latents set to `True` and have a directory of cached latents,
47 | # you can skip the caching process and load previously saved ones.
48 | cached_latent_dir: null #/path/to/cached_latents
49 |
50 | # Train the text encoder for the model. LoRA Training overrides this setting.
51 | train_text_encoder: False
52 |
53 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
54 | # This is the first, original implementation of LoRA by cloneofsimo.
55 | # Use this version if you want to maintain compatibility with the original version.
56 |
57 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
58 | # This is an implementation based on the original LoRA repository by Microsoft, and is the default LoRA method here.
59 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
60 | # This means that there isn't an extra function when doing low-rank adaptation.
61 | # It solely saves the weight difference between the initialized weights and the updates.
62 |
63 | # "cloneofsimo" or "stable_lora"
64 | lora_version: "cloneofsimo"
65 |
66 | # Use LoRA for the UNET model.
67 | use_unet_lora: False
68 |
69 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
70 | use_text_lora: False
71 |
72 | # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
73 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
74 | lora_unet_dropout: 0.1
75 |
76 | lora_text_dropout: 0.1
77 |
78 | # https://github.com/kabachuha/sd-webui-text2video
79 | # This saves a LoRA that is compatible with the text2video webui extension.
80 | # It only works when the lora version is 'stable_lora'.
81 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT work with Kohya's implementation.
82 | save_lora_for_webui: True
83 |
84 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
85 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
86 | # when this option is set to False.
87 | only_lora_for_webui: False
88 |
89 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
90 | # The only time you want this off is if you're doing full LoRA training.
91 | save_pretrained_model: True
92 |
93 | # The modules to use for LoRA. Different from 'trainable_modules'.
94 | unet_lora_modules:
95 | - "UNet3DConditionModel"
96 | #- "ResnetBlock2D"
97 | #- "TransformerTemporalModel"
98 | #- "Transformer2DModel"
99 | #- "CrossAttention"
100 | #- "Attention"
101 | #- "GEGLU"
102 | #- "TemporalConvLayer"
103 |
104 | # The modules to use for LoRA. Different from `trainable_text_modules`.
105 | text_encoder_lora_modules:
106 | - "CLIPEncoderLayer"
107 | #- "CLIPAttention"
108 |
109 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
110 | # VRAM usage increases with higher rank and decreases with lower rank.
111 | lora_rank: 16
112 |
113 | # Training data parameters
114 | train_data:
115 | width: 512
116 | height: 512
117 | use_bucketing: False
118 | return_mask: True
119 | return_motion: True
120 | sample_start_idx: 1
121 | fps: 8
122 | n_sample_frames: 16
123 | json_path: '/webvid/animation0.json'
124 |
125 | # Validation data parameters.
126 | validation_data:
127 |
128 | # A custom prompt that is different from your training dataset.
129 | prompt: "a girl moves hands"
130 |
131 | prompt_image: "output/example/girl4.jpg"
132 |
133 | # Whether or not to sample preview during training (Requires more VRAM).
134 | sample_preview: True
135 |
136 | # The number of frames to sample during validation.
137 | num_frames: 16
138 |
139 | # Height and width of validation sample.
140 | width: 512
141 | height: 512
142 |
143 | # Number of inference steps when generating the video.
144 | num_inference_steps: 25
145 |
146 | # CFG scale
147 | guidance_scale: 9
148 |
149 | # Learning rate for AdamW
150 | learning_rate: 5.0e-06
151 |
152 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
153 | adam_weight_decay: 0
154 |
155 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
156 | extra_unet_params: null
157 | #learning_rate: 1e-5
158 | #adam_weight_decay: 1e-4
159 |
160 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
161 | extra_text_encoder_params: null
162 | #learning_rate: 1e-4
163 | #adam_weight_decay: 0.2
164 |
165 | # The number of samples per training batch. Not to be confused with video frames.
166 | train_batch_size: 8
167 | # Maximum number of train steps. Model is saved after training.
168 | max_train_steps: 5000
169 |
170 | # Saves a model every nth step.
171 | checkpointing_steps: 1000
172 |
173 | # How many steps to do for validation if sample_preview is enabled.
174 | validation_steps: 200
175 |
176 | # Which modules we want to unfreeze for the UNET. Advanced usage.
177 | trainable_modules:
178 | #- "all"
179 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions"
180 | # This is for self attention. Activates for spatial and temporal dimensions if n_sample_frames > 1
181 | - "attn1"
182 | - ".attentions"
183 |
184 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1
185 | - "attn2"
186 | - "conv_in"
187 |
188 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1
189 | - "temp_conv"
190 | - "motion"
191 |
192 |
193 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
194 | trainable_text_modules: null
195 |
196 | # Seed for validation.
197 | seed: null
198 |
199 | # Whether or not we want to use mixed precision with accelerate
200 | mixed_precision: "fp16"
201 |
202 | # This seems to be incompatible at the moment.
203 | use_8bit_adam: False
204 |
205 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
206 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
207 | gradient_checkpointing: True
208 | text_encoder_gradient_checkpointing: False
209 |
210 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
211 | enable_xformers_memory_efficient_attention: False
212 |
213 | # Use scaled dot product attention (Only available with >= Torch 2.0)
214 | enable_torch_2_attn: True
215 |
--------------------------------------------------------------------------------
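The `use_offset_noise` / `offset_noise_strength` options in train_mask_motion.yaml above follow the offset-noise trick from the linked crosslabs post: a small per-channel constant is added to the sampled noise so the model can learn global brightness shifts instead of only high-frequency detail. A minimal sketch of that idea (a hypothetical helper, not the repo's actual noise-sampling code in train.py):

```python
# Hypothetical helper illustrating offset noise; the real logic lives in the training loop.
import torch

def sample_noise(latents: torch.Tensor, strength: float, use_offset_noise: bool) -> torch.Tensor:
    noise = torch.randn_like(latents)
    if use_offset_noise:
        b, c = latents.shape[:2]
        extra = (1,) * (latents.ndim - 2)  # broadcast over frames / height / width
        # Per-sample, per-channel constant offset scaled by offset_noise_strength.
        noise = noise + strength * torch.randn(b, c, *extra, device=latents.device)
    return noise
```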
/example/train_mask_motion_lora.yaml:
--------------------------------------------------------------------------------
1 | # running scripts:
2 | # accelerate launch --config_file example/deepspeed.yaml train_lora.py --config example/train_mask_motion_lora.yaml
3 | # python train_lora.py --config example/train_mask_motion_lora.yaml --eval
4 |
5 | # Pretrained diffusers model path.
6 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main
7 | pretrained_model_path: "output/latent/animate_anything_512_v1.02"
8 |
9 | # pretrained lora path
10 | # lora_path is only valid during eval (--eval).
11 | # lora module is saved to {output_dir}/train_{datetime}/{checkpoint}/lora by default during training
12 | lora_path: "/path/to/your_lora_module"
13 |
14 | motion_mask: True
15 | motion_strength: True
16 |
17 | # The folder where your training outputs will be placed.
18 | output_dir: "./output/latent"
19 |
20 | # You can train multiple datasets at once. They will be joined together for training.
21 | # Simply remove the line you don't need, or keep them all for mixed training.
22 |
23 | # 'image': A folder of images and captions (.txt)
24 | # 'folder': A folder of videos and captions (.txt)
25 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
26 | # 'video_json': A video folder and a JSON caption file
27 | # 'single_video': A single video file (.mp4) and a text prompt
28 | dataset_types:
29 | #- 'single_video'
30 | #- 'folder'
31 | #- 'image'
32 | - 'video_blip'
33 | #- 'video_json'
34 |
35 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
36 | # If this is enabled, rescale_schedule will be disabled.
37 | offset_noise_strength: 0.1
38 | use_offset_noise: False
39 |
40 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
41 | # If this is enabled, offset noise will be disabled.
42 | rescale_schedule: True
43 |
44 | # When True, this extends all items in all enabled datasets to the highest length.
45 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
46 | extend_dataset: False
47 |
48 | # Caches the latents (Frames-Image -> VAE -> Latent) to an HDD or SSD.
49 | # The latents will be saved under your training folder, and loaded automatically for training.
50 | # This saves memory, speeds up training, and takes very little disk space.
51 | cache_latents: False
52 |
53 | # If you have cached latents set to `True` and have a directory of cached latents,
54 | # you can skip the caching process and load previously saved ones.
55 | cached_latent_dir: null #/path/to/cached_latents
56 |
57 | # Train the text encoder for the model. LoRA Training overrides this setting.
58 | train_text_encoder: False
59 |
60 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
61 | # This is the first, original implementation of LoRA by cloneofsimo.
62 | # Use this version if you want to maintain compatibility with the original version.
63 |
64 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
65 | # This is an implementation based on the original LoRA repository by Microsoft, and is the default LoRA method here.
66 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
67 | # This means that there isn't an extra function when doing low-rank adaptation.
68 | # It solely saves the weight difference between the initialized weights and the updates.
69 |
70 | # "cloneofsimo" or "stable_lora"
71 | lora_version: "cloneofsimo"
72 |
73 | # Use LoRA for the UNET model.
74 | use_unet_lora: True
75 |
76 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
77 | use_text_lora: False
78 |
80 | # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
80 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
81 | lora_unet_dropout: 0.1
82 |
83 | lora_text_dropout: 0.1
84 |
85 | # https://github.com/kabachuha/sd-webui-text2video
86 | # This saves a LoRA that is compatible with the text2video webui extension.
87 | # It only works when the lora version is 'stable_lora'.
88 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT work with Kohya's implementation.
89 | save_lora_for_webui: True
90 |
91 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
92 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
93 | # when this option is set to False.
94 | only_lora_for_webui: False
95 |
97 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
97 | # The only time you want this off is if you're doing full LoRA training.
98 | save_pretrained_model: True
99 |
100 | # The modules to use for LoRA. Different from 'trainable_modules'.
101 | unet_lora_modules:
102 | - "UNet3DConditionModel"
103 | #- "ResnetBlock2D"
104 | #- "TransformerTemporalModel"
105 | #- "Transformer2DModel"
106 | #- "CrossAttention"
107 | #- "Attention"
108 | #- "GEGLU"
109 | #- "TemporalConvLayer"
110 |
111 | # The modules to use for LoRA. Different from `trainable_text_modules`.
112 | text_encoder_lora_modules:
113 | - "CLIPEncoderLayer"
114 | #- "CLIPAttention"
115 |
116 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
117 | # VRAM usage increases with higher rank and decreases with lower rank.
118 | lora_rank: 16
119 |
120 | # Training data parameters
121 | train_data:
122 | width: 512
123 | height: 512
124 | use_bucketing: False
125 | return_mask: True
126 | return_motion: True
127 | sample_start_idx: 1
128 | fps: 8
129 | n_sample_frames: 16
130 | json_path: '/webvid/animation0.json'
131 |
132 | # Validation data parameters.
133 | validation_data:
134 |
135 | # A custom prompt that is different from your training dataset.
136 | prompt: "a girl smiling"
137 |
138 | prompt_image: "example/barbie.jpg"
139 |
140 | # Whether or not to sample preview during training (Requires more VRAM).
141 | sample_preview: True
142 |
143 | # The number of frames to sample during validation.
144 | num_frames: 16
145 |
146 | # Height and width of validation sample.
147 | width: 512
148 | height: 512
149 |
150 | # Number of inference steps when generating the video.
151 | num_inference_steps: 25
152 |
153 | # CFG scale
154 | guidance_scale: 9
155 |
156 | # Learning rate for AdamW
157 | learning_rate: 5.0e-06
158 |
159 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
160 | adam_weight_decay: 0
161 |
162 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
163 | extra_unet_params: null
164 | #learning_rate: 1e-5
165 | #adam_weight_decay: 1e-4
166 |
167 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
168 | extra_text_encoder_params: null
169 | #learning_rate: 1e-4
170 | #adam_weight_decay: 0.2
171 |
172 | # The number of samples per training batch. Not to be confused with video frames.
173 | train_batch_size: 4
174 | # Maximum number of train steps. Model is saved after training.
175 | max_train_steps: 1000
176 |
177 | # Saves a model every nth step.
178 | checkpointing_steps: 100
179 |
180 | # How many steps to do for validation if sample_preview is enabled.
181 | validation_steps: 100
182 |
183 | # Which modules we want to unfreeze for the UNET. Advanced usage.
184 | # trainable_modules:
185 | # - "None"
186 |
187 |
188 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
189 | trainable_text_modules: null
190 |
191 | # Seed for validation.
192 | seed: null
193 |
194 | # Whether or not we want to use mixed precision with accelerate
195 | mixed_precision: "fp16"
196 |
197 | # This seems to be incompatible at the moment.
198 | use_8bit_adam: False
199 |
200 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
201 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
202 | gradient_checkpointing: True
203 | text_encoder_gradient_checkpointing: False
204 |
205 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
206 | enable_xformers_memory_efficient_attention: False
207 |
208 | # Use scaled dot product attention (Only available with >= Torch 2.0)
209 | enable_torch_2_attn: True
210 |
--------------------------------------------------------------------------------
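To make `lora_rank` and `lora_unet_dropout` in the LoRA config above concrete, here is a generic low-rank adapter around a single `nn.Linear`. This is an illustrative sketch of the LoRA idea only, not the repo's implementation (see utils/lora.py and stable_lora/lora.py for that); the class name and defaults are invented for the example.

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    """Toy LoRA wrapper: a frozen base layer plus a trainable rank-r update."""

    def __init__(self, base: nn.Linear, rank: int = 16, dropout: float = 0.1, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                 # the pretrained weight stays frozen
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)              # start as a no-op update
        self.dropout = nn.Dropout(dropout)          # what lora_unet_dropout controls
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(self.dropout(x)))

layer = LoRALinear(nn.Linear(320, 320), rank=16, dropout=0.1)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 2 * 320 * 16 parameters train instead of the full 320 * 320
```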
/example/train_svd.yaml:
--------------------------------------------------------------------------------
1 | # Pretrained diffusers model path.
2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main
3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid"
4 |
5 |
6 | motion_mask: False
7 | motion_strength: False
8 |
9 | # The folder where your training outputs will be placed.
10 | output_dir: "./output/svd"
11 |
12 | # You can train multiple datasets at once. They will be joined together for training.
13 | # Simply remove the line you don't need, or keep them all for mixed training.
14 |
15 | # 'image': A folder of images and captions (.txt)
16 | # 'folder': A folder of videos and captions (.txt)
17 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
18 | # 'video_json': A video folder and a JSON caption file
19 | # 'single_video': A single video file (.mp4) and a text prompt
20 | dataset_types:
21 | #- 'single_video'
22 | #- 'folder'
23 | #- 'image'
24 | - 'video_blip'
25 | #- 'video_json'
26 |
27 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
28 | # If this is enabled, rescale_schedule will be disabled.
29 | offset_noise_strength: 0.1
30 | use_offset_noise: False
31 |
32 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
33 | # If this is enabled, offset noise will be disabled.
34 | rescale_schedule: False
35 |
36 | # When True, this extends all items in all enabled datasets to the highest length.
37 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
38 | extend_dataset: False
39 |
40 | # Caches the latents (Frames-Image -> VAE -> Latent) to an HDD or SSD.
41 | # The latents will be saved under your training folder, and loaded automatically for training.
42 | # This saves memory, speeds up training, and takes very little disk space.
43 | cache_latents: False
44 |
45 | # If you have cached latents set to `True` and have a directory of cached latents,
46 | # you can skip the caching process and load previously saved ones.
47 | cached_latent_dir: null #/path/to/cached_latents
48 |
49 | # Train the text encoder for the model. LoRA Training overrides this setting.
50 | train_text_encoder: False
51 |
52 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
53 | # This is the first, original implementation of LoRA by cloneofsimo.
54 | # Use this version if you want to maintain compatibility with the original version.
55 |
56 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
57 | # This is an implementation based on the original LoRA repository by Microsoft, and is the default LoRA method here.
58 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
59 | # This means that there isn't an extra function when doing low-rank adaptation.
60 | # It solely saves the weight difference between the initialized weights and the updates.
61 |
62 | # "cloneofsimo" or "stable_lora"
63 | lora_version: "cloneofsimo"
64 |
65 | # Use LoRA for the UNET model.
66 | use_unet_lora: False
67 |
68 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
69 | use_text_lora: False
70 |
71 | # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
72 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
73 | lora_unet_dropout: 0.1
74 |
75 | lora_text_dropout: 0.1
76 |
77 | # https://github.com/kabachuha/sd-webui-text2video
78 | # This saves a LoRA that is compatible with the text2video webui extension.
79 | # It only works when the lora version is 'stable_lora'.
80 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT work with Kohya's implementation.
81 | save_lora_for_webui: True
82 |
83 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
84 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
85 | # when this option is set to False.
86 | only_lora_for_webui: False
87 |
88 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
89 | # The only time you want this off is if you're doing full LoRA training.
90 | save_pretrained_model: True
91 |
92 | # The modules to use for LoRA. Different from 'trainable_modules'.
93 | unet_lora_modules:
94 | - "UNet3DConditionModel"
95 | #- "ResnetBlock2D"
96 | #- "TransformerTemporalModel"
97 | #- "Transformer2DModel"
98 | #- "CrossAttention"
99 | #- "Attention"
100 | #- "GEGLU"
101 | #- "TemporalConvLayer"
102 |
103 | # The modules to use for LoRA. Different from `trainable_text_modules`.
104 | text_encoder_lora_modules:
105 | - "CLIPEncoderLayer"
106 | #- "CLIPAttention"
107 |
108 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
109 | # VRAM usage increases with higher rank and decreases with lower rank.
110 | lora_rank: 16
111 |
112 | # Training data parameters
113 | train_data:
114 |
115 | # The width and height to which your training data will be resized.
116 | width: 512
117 | height: 512
118 |
119 | # This will find the closest aspect ratio to your input width and height.
120 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
121 | use_bucketing: False
122 | return_mask: True
123 | return_motion: True
124 |
125 | # The start frame index where your videos should start (Leave this at one for json and folder based training).
126 | sample_start_idx: 1
127 |
128 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset.
129 | # A higher fps (or a lower frame step) means subjects move more slowly between sampled frames.
130 | fps: 7
131 |
132 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...).
133 | frame_step: 1
134 |
135 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size).
136 | n_sample_frames: 7
137 |
138 | # 'single_video'
139 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4"
140 |
141 | # The prompt when using a single video file
142 | single_video_prompt: "a bird"
143 |
144 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'.
145 | fallback_prompt: ''
146 |
147 | # 'folder'
148 | #path: "/data2/webvid/data/videos/004151_004200"
149 | path: "/data/datasets/msvd/videos_mp4"
150 |
151 | # 'json'
152 | json_path: '/webvid/animation1.json'
153 |
154 | # 'image'
155 | image_dir: '/vlp/datasets/images/coco'
156 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json'
157 |
158 | video_dir: '/webvid/webvid/data/videos'
159 | video_json: '/webvid/webvid/data/40K.json'
160 | # The prompt for all image files. Leave blank to use caption files (.txt)
161 | single_img_prompt: ""
162 |
163 |
164 | # Validation data parameters.
165 | validation_data:
166 |
167 | # A custom prompt that is different from your training dataset.
168 | prompt: "a girl moves body"
169 |
170 | prompt_image: "output/example/girl4.jpg"
171 |
172 | # Whether or not to sample preview during training (Requires more VRAM).
173 | sample_preview: True
174 |
175 | # The number of frames to sample during validation.
176 | num_frames: 14
177 |
178 | # Height and width of validation sample.
179 | width: 512
180 | height: 512
181 |
182 | # Number of inference steps when generating the video.
183 | num_inference_steps: 25
184 |
185 | # CFG scale
186 | guidance_scale: 9
187 |
188 | # fps
189 | fps: 7
190 |
191 | # The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
192 | motion_bucket_id: 127
193 |
194 | # The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency between frames,
195 | # but also the higher the memory consumption. By default, the decoder will decode all frames at once for maximal quality.
196 | # Reduce `decode_chunk_size` to reduce memory usage.
197 | decode_chunk_size: 7
198 |
199 | # Learning rate for AdamW
200 | learning_rate: 5e-6
201 |
202 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
203 | adam_weight_decay: 0
204 |
205 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
206 | extra_unet_params: null
207 | #learning_rate: 1e-5
208 | #adam_weight_decay: 1e-4
209 |
210 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
211 | extra_text_encoder_params: null
212 | #learning_rate: 1e-4
213 | #adam_weight_decay: 0.2
214 |
215 | # The number of samples per training batch. Not to be confused with video frames.
216 | train_batch_size: 1
217 | # Maximum number of train steps. Model is saved after training.
218 | max_train_steps: 10000
219 |
220 | # Saves a model every nth step.
221 | checkpointing_steps: 2500
222 |
223 | # How many steps to do for validation if sample_preview is enabled.
224 | validation_steps: 300
225 |
226 | # Which modules we want to unfreeze for the UNET. Advanced usage.
227 | trainable_modules:
228 | - "all"
229 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions"
230 | # This is for self attention. Activates for spatial and temporal dimensions if n_sample_frames > 1
231 | - "attn1"
232 | #- ".attentions"
233 |
234 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1
235 | - "attn2"
236 | - "conv_in"
237 |
238 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1
239 | - "temp_conv"
240 | - "motion"
241 |
242 |
243 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
244 | trainable_text_modules: null
245 |
246 | # Seed for validation.
247 | seed: 6
248 |
249 | # Whether or not we want to use mixed precision with accelerate
250 | mixed_precision: "fp16"
251 |
252 | # This seems to be incompatible at the moment.
253 | use_8bit_adam: False
254 |
255 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
256 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
257 | gradient_checkpointing: True
258 | text_encoder_gradient_checkpointing: False
259 |
260 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
261 | enable_xformers_memory_efficient_attention: False
262 |
263 | # Use scaled dot product attention (Only available with >= Torch 2.0)
264 | enable_torch_2_attn: True
265 |
--------------------------------------------------------------------------------
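The `sample_start_idx`, `frame_step`, and `n_sample_frames` options in train_svd.yaml above interact when picking frames from a clip. A hypothetical helper showing only the index arithmetic (the real sampling logic lives in utils/dataset.py):

```python
# Hypothetical sketch; illustrative only.
def frame_indices(sample_start_idx: int, frame_step: int, n_sample_frames: int) -> list:
    return [sample_start_idx + i * frame_step for i in range(n_sample_frames)]

print(frame_indices(1, 1, 7))  # [1, 2, 3, 4, 5, 6, 7]
print(frame_indices(1, 2, 7))  # [1, 3, 5, 7, 9, 11, 13] -- the "(1,3,5,7, ...)" pattern from the comment
```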
/example/train_svd_mask.yaml:
--------------------------------------------------------------------------------
1 | # Pretrained diffusers model path.
2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main
3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid-mask"
4 |
5 | motion_mask: True
6 | motion_strength: False
7 |
8 | # The folder where your training outputs will be placed.
9 | output_dir: "./output/svd"
10 |
11 | # You can train multiple datasets at once. They will be joined together for training.
12 | # Simply remove the line you don't need, or keep them all for mixed training.
13 |
14 | # 'image': A folder of images and captions (.txt)
15 | # 'folder': A folder of videos and captions (.txt)
16 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
17 | # 'video_json': A video folder and a JSON caption file
18 | # 'single_video': A single video file (.mp4) and a text prompt
19 | dataset_types:
20 | #- 'single_video'
21 | #- 'folder'
22 | #- 'image'
23 | - 'video_blip'
24 | #- 'video_json'
25 |
26 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
27 | # If this is enabled, rescale_schedule will be disabled.
28 | offset_noise_strength: 0.1
29 | use_offset_noise: False
30 |
31 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
32 | # If this is enabled, offset noise will be disabled.
33 | rescale_schedule: False
34 |
35 | # When True, this extends all items in all enabled datasets to the highest length.
36 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
37 | extend_dataset: False
38 |
39 | # Caches the latents (Frames-Image -> VAE -> Latent) to an HDD or SSD.
40 | # The latents will be saved under your training folder, and loaded automatically for training.
41 | # This saves memory, speeds up training, and takes very little disk space.
42 | cache_latents: False
43 |
44 | # If you have cached latents set to `True` and have a directory of cached latents,
45 | # you can skip the caching process and load previously saved ones.
46 | cached_latent_dir: null #/path/to/cached_latents
47 |
48 | # Train the text encoder for the model. LoRA Training overrides this setting.
49 | train_text_encoder: False
50 |
51 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
52 | # This is the first, original implementation of LoRA by cloneofsimo.
53 | # Use this version if you want to maintain compatibility with the original version.
54 |
55 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
56 | # This is an implementation based on the original LoRA repository by Microsoft, and is the default LoRA method here.
57 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
58 | # This means that there isn't an extra function when doing low-rank adaptation.
59 | # It solely saves the weight difference between the initialized weights and the updates.
60 |
61 | # "cloneofsimo" or "stable_lora"
62 | lora_version: "cloneofsimo"
63 |
64 | # Use LoRA for the UNET model.
65 | use_unet_lora: False
66 |
67 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
68 | use_text_lora: False
69 |
70 | # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
71 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
72 | lora_unet_dropout: 0.1
73 |
74 | lora_text_dropout: 0.1
75 |
76 | # https://github.com/kabachuha/sd-webui-text2video
77 | # This saves a LoRA that is compatible with the text2video webui extension.
78 | # It only works when the lora version is 'stable_lora'.
79 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT work with Kohya's implementation.
80 | save_lora_for_webui: True
81 |
82 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
83 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
84 | # when this option is set to False.
85 | only_lora_for_webui: False
86 |
87 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
88 | # The only time you want this off is if you're doing full LoRA training.
89 | save_pretrained_model: True
90 |
91 | # The modules to use for LoRA. Different from 'trainable_modules'.
92 | unet_lora_modules:
93 | - "UNet3DConditionModel"
94 | #- "ResnetBlock2D"
95 | #- "TransformerTemporalModel"
96 | #- "Transformer2DModel"
97 | #- "CrossAttention"
98 | #- "Attention"
99 | #- "GEGLU"
100 | #- "TemporalConvLayer"
101 |
102 | # The modules to use for LoRA. Different from `trainable_text_modules`.
103 | text_encoder_lora_modules:
104 | - "CLIPEncoderLayer"
105 | #- "CLIPAttention"
106 |
107 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
108 | # VRAM usage increases with higher rank and decreases with lower rank.
109 | lora_rank: 16
110 |
111 | # Training data parameters
112 | train_data:
113 |
114 | # The width and height to which your training data will be resized.
115 | width: 512
116 | height: 512
117 |
118 | # This will find the closest aspect ratio to your input width and height.
119 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
120 | use_bucketing: False
121 | return_mask: True
122 | return_motion: True
123 |
124 | # The start frame index where your videos should start (Leave this at one for json and folder based training).
125 | sample_start_idx: 1
126 |
127 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset.
128 | # A higher fps (or a lower frame step) means subjects move more slowly between sampled frames.
129 | fps: 7
130 |
131 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...).
132 | frame_step: 1
133 |
134 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size).
135 | n_sample_frames: 12
136 |
137 | # 'single_video'
138 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4"
139 |
140 | # The prompt when using a single video file
141 | single_video_prompt: "a bird"
142 |
143 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'.
144 | fallback_prompt: ''
145 |
146 | # 'folder'
147 | #path: "/data2/webvid/data/videos/004151_004200"
148 | path: "/data/datasets/msvd/videos_mp4"
149 |
150 | # 'json'
151 | json_path: '/webvid/animation1.json'
152 |
153 | # 'image'
154 | image_dir: '/vlp/datasets/images/coco'
155 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json'
156 |
157 | video_dir: '/webvid/webvid/data/videos'
158 | video_json: '/webvid/webvid/data/40K.json'
159 | # The prompt for all image files. Leave blank to use caption files (.txt)
160 | single_img_prompt: ""
161 |
162 |
163 | # Validation data parameters.
164 | validation_data:
165 |
166 | # A custom prompt that is different from your training dataset.
167 | prompt: "a girl moves body"
168 |
169 | prompt_image: "output/example/fish1.jpg"
170 |
171 | # Whether or not to sample preview during training (Requires more VRAM).
172 | sample_preview: True
173 |
174 | # The number of frames to sample during validation.
175 | num_frames: 14
176 |
177 | # Height and width of validation sample.
178 | width: 512
179 | height: 512
180 |
181 | # Number of inference steps when generating the video.
182 | num_inference_steps: 25
183 |
184 | # CFG scale
185 | guidance_scale: 9
186 |
187 | # fps
188 | fps: 7
189 |
190 | # The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
191 | motion_bucket_id: 127
192 |
193 | # The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency between frames,
194 | # but also the higher the memory consumption. By default, the decoder will decode all frames at once for maximal quality.
195 | # Reduce `decode_chunk_size` to reduce memory usage.
196 | decode_chunk_size: 7
197 |
198 | # Learning rate for AdamW
199 | learning_rate: 2e-5
200 |
201 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
202 | adam_weight_decay: 0
203 |
204 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
205 | extra_unet_params: null
206 | #learning_rate: 1e-5
207 | #adam_weight_decay: 1e-4
208 |
209 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
210 | extra_text_encoder_params: null
211 | #learning_rate: 1e-4
212 | #adam_weight_decay: 0.2
213 |
214 | # The number of samples per training batch. Not to be confused with video frames.
215 | train_batch_size: 3
216 | # Maximum number of train steps. Model is saved after training.
217 | max_train_steps: 20000
218 |
219 | # Saves a model every nth step.
220 | checkpointing_steps: 2500
221 |
222 | # How many steps to do for validation if sample_preview is enabled.
223 | validation_steps: 100
224 |
225 | # Which modules we want to unfreeze for the UNET. Advanced usage.
226 | trainable_modules:
227 | - "all"
228 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions"
229 | # This is for self attention. Activates for spatial and temporal dimensions if n_sample_frames > 1
230 | - "attn1"
231 | - ".attentions"
232 |
233 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1
234 | - "attn2"
235 | - "conv_in"
236 |
237 | # Convolution networks that hold temporal information. Activates for spatial and temporal dimensions if n_sample_frames > 1
238 | - "temp_conv"
239 | - "motion"
240 |
241 |
242 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
243 | trainable_text_modules: null
244 |
245 | # Seed for validation.
246 | seed: 6
247 |
248 | # Whether or not we want to use mixed precision with accelerate
249 | mixed_precision: "fp16"
250 |
251 | # This seems to be incompatible at the moment.
252 | use_8bit_adam: False
253 |
254 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
255 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
256 | gradient_checkpointing: True
257 | text_encoder_gradient_checkpointing: False
258 |
259 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
260 | enable_xformers_memory_efficient_attention: False
261 |
262 | # Use scaled dot product attention (Only available with >= Torch 2.0)
263 | enable_torch_2_attn: True
264 |
--------------------------------------------------------------------------------
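The `decode_chunk_size` comment in train_svd_mask.yaml above describes a memory/quality trade-off in the VAE decoder: decoding all latent frames at once is most consistent, while decoding in chunks bounds peak VRAM. A minimal sketch of chunked decoding (illustrative only; the SVD pipeline in diffusers handles this internally, and `decode_fn` is a stand-in for the real decoder call):

```python
import torch

def decode_in_chunks(decode_fn, latents: torch.Tensor, decode_chunk_size: int) -> torch.Tensor:
    # latents: (num_frames, channels, height, width); decode_fn maps a latent batch to frames.
    frames = [
        decode_fn(latents[i:i + decode_chunk_size])
        for i in range(0, latents.shape[0], decode_chunk_size)
    ]
    return torch.cat(frames, dim=0)
```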
/example/train_svd_v2v.yaml:
--------------------------------------------------------------------------------
1 | # Pretrained diffusers model path.
2 | #pretrained_model_path: "output/latent/train_mask_motion" #https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/tree/main
3 | pretrained_model_path: "/webvid/llm/stable-video-diffusion-img2vid"
4 |
5 |
6 | motion_mask: True
7 | motion_strength: False
8 |
9 | # The folder where your training outputs will be placed.
10 | output_dir: "./output/svd"
11 |
12 | # You can train multiple datasets at once. They will be joined together for training.
13 | # Simply remove the line you don't need, or keep them all for mixed training.
14 |
15 | # 'image': A folder of images and captions (.txt)
16 | # 'folder': A folder of videos and captions (.txt)
17 | # 'video_blip': The JSON file created with automatic BLIP2 captions using https://github.com/ExponentialML/Video-BLIP2-Preprocessor
18 | # 'video_json': A video folder and a JSON caption file
19 | # 'single_video': A single video file (.mp4) and a text prompt
20 | dataset_types:
21 | #- 'single_video'
22 | #- 'folder'
23 | #- 'image'
24 | - 'video_blip'
25 | #- 'video_json'
26 |
27 | # Adds offset noise to training. See https://www.crosslabs.org/blog/diffusion-with-offset-noise
28 | # If this is enabled, rescale_schedule will be disabled.
29 | offset_noise_strength: 0.1
30 | use_offset_noise: False
31 |
32 | # Uses schedule rescale, also known as the "better" offset noise. See https://arxiv.org/pdf/2305.08891.pdf
33 | # If this is enabled, offset noise will be disabled.
34 | rescale_schedule: False
35 |
36 | # When True, this extends all items in all enabled datasets to the highest length.
37 | # For example, if you have 200 videos and 10 images, 10 images will be duplicated to the length of 200.
38 | extend_dataset: False
39 |
40 | # Caches the latents (Frames-Image -> VAE -> Latent) to an HDD or SSD.
41 | # The latents will be saved under your training folder, and loaded automatically for training.
42 | # This saves memory, speeds up training, and takes very little disk space.
43 | cache_latents: False
44 |
45 | # If you have cached latents set to `True` and have a directory of cached latents,
46 | # you can skip the caching process and load previously saved ones.
47 | cached_latent_dir: null #/path/to/cached_latents
48 |
49 | # https://github.com/cloneofsimo/lora (NOT Compatible with webui extension)
50 | # This is the first, original implementation of LoRA by cloneofsimo.
51 | # Use this version if you want to maintain compatibility with the original version.
52 |
53 | # https://github.com/ExponentialML/Stable-LoRA/tree/main (Compatible with webui text2video extension)
54 | # This is an implementation based on the original LoRA repository by Microsoft, and is the default LoRA method here.
55 | # It works differently by using embeddings instead of the intermediate activations (Linear || Conv).
56 | # This means that there isn't an extra function when doing low-rank adaptation.
57 | # It solely saves the weight difference between the initialized weights and the updates.
58 |
59 | # "cloneofsimo" or "stable_lora"
60 | lora_version: "cloneofsimo"
61 |
62 | # Use LoRA for the UNET model.
63 | use_unet_lora: False
64 |
65 | # Use LoRA for the Text Encoder. If this is set, the text encoder for the model will not be trained.
66 | use_text_lora: False
67 |
68 | # LoRA Dropout. This parameter sets the probability of randomly zeroing out elements. Helps prevent overfitting.
69 | # See: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
70 | lora_unet_dropout: 0.1
71 |
72 | lora_text_dropout: 0.1
73 |
74 | # https://github.com/kabachuha/sd-webui-text2video
75 | # This saves a LoRA that is compatible with the text2video webui extension.
76 | # It only works when the lora version is 'stable_lora'.
77 | # This is also a DIFFERENT implementation from Kohya's, so it will NOT work with Kohya's implementation.
78 | save_lora_for_webui: True
79 |
80 | # The LoRA file will be converted to a different format to be compatible with the webui extension.
81 | # The difference between this and 'save_lora_for_webui' is that you can continue training a Diffusers pipeline model
82 | # when this option is set to False.
83 | only_lora_for_webui: False
84 |
85 | # Choose whether or not to save the full pretrained model weights for both checkpoints and after training.
86 | # The only time you want this off is if you're doing full LoRA training.
87 | save_pretrained_model: True
88 |
89 | # The modules to use for LoRA. Different from 'trainable_modules'.
90 | unet_lora_modules:
91 | - "UNet3DConditionModel"
92 | #- "ResnetBlock2D"
93 | #- "TransformerTemporalModel"
94 | #- "Transformer2DModel"
95 | #- "CrossAttention"
96 | #- "Attention"
97 | #- "GEGLU"
98 | #- "TemporalConvLayer"
99 |
100 | # The modules to use for LoRA. Different from `trainable_text_modules`.
101 | text_encoder_lora_modules:
102 | - "CLIPEncoderLayer"
103 | #- "CLIPAttention"
104 |
105 | # The rank for LoRA training. With ModelScope, the maximum should be 1024.
106 | # VRAM usage increases with higher rank and decreases with lower rank.
107 | lora_rank: 16
108 |
109 | # Training data parameters
110 | train_data:
111 |
112 | # The width and height to which your training data will be resized.
113 | width: 576
114 | height: 1024
115 |
116 | # This will find the closest aspect ratio to your input width and height.
117 | # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
118 | use_bucketing: False
119 | return_mask: True
120 | return_motion: True
121 |
122 | # The start frame index where your videos should start (Leave this at one for json and folder based training).
123 | sample_start_idx: 1
124 |
125 | # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset.
126 | # A higher fps (or a lower frame step) means subjects move more slowly between sampled frames.
127 | fps: 7
128 |
129 | # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...).
130 | frame_step: 1
131 |
132 | # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size).
133 | n_sample_frames: 14
134 |
135 | # 'single_video'
136 | single_video_path: "/data/datasets/animal_kingdom/video_grounding/dataset/AADJBFXO.mp4"
137 |
138 | # The prompt when using a single video file
139 | single_video_prompt: "a bird"
140 |
141 | # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'.
142 | fallback_prompt: ''
143 |
144 | # 'folder'
145 | #path: "/data2/webvid/data/videos/004151_004200"
146 | path: "/data/datasets/msvd/videos_mp4"
147 |
148 | # 'json'
149 | json_path: '/webvid/animation2.json'
150 |
151 | # 'image'
152 | image_dir: '/vlp/datasets/images/coco'
153 | image_json: '/vlp/datasets/images/coco/coco_karpathy_train.json'
154 |
155 | #video_dir: '/mnt/cap/zuozhuo/webvid/webvid/data/videos'
156 | #video_json: '/mnt/cap/zuozhuo/webvid/webvid/data/1M.json'
157 | # The prompt for all image files. Leave blank to use caption files (.txt)
158 | single_img_prompt: ""
159 |
160 | video_dir: '/webvid/webvid/data/videos'
161 | video_json: '/webvid/webvid/data/40K.json'
162 |
163 | extra_train_data:
164 | - dataset_types:
165 | - video_blip
166 | train_data:
167 | json_path: '/webvid/animation_dataset_clips_part_0.json'
168 | - dataset_types:
169 | - video_blip
170 | train_data:
171 | json_path: '/webvid/animation_dataset_clips_part_1.json'
172 | - dataset_types:
173 | - video_blip
174 | train_data:
175 | json_path: '/webvid/animation_dataset_clips_part_2.json'
176 | - dataset_types:
177 | - video_blip
178 | train_data:
179 | json_path: '/webvid/animation0.json'
180 | - dataset_types:
181 | - video_blip
182 | train_data:
183 | json_path: '/webvid/animation1.json'
184 |
185 |
186 | # Validation data parameters.
187 | validation_data:
188 |
189 | # A custom prompt that is different from your training dataset.
190 | prompt: "The fish is swimming."
191 |
192 | prompt_image: "output/example/fish_512.mp4"
193 |
194 | # Whether or not to sample preview during training (Requires more VRAM).
195 | sample_preview: True
196 |
197 | # The number of frames to sample during validation.
198 | num_frames: 14
199 |
200 | # Height and width of validation sample.
201 | width: 512
202 | height: 512
203 | # Number of inference steps when generating the video.
204 | num_inference_steps: 25
205 | fps: 7
206 | # CFG scale
207 | guidance_scale: 3
208 |
209 | # Learning rate for AdamW
210 | learning_rate: 5e-6
211 |
212 | # Weight decay. Higher = more regularization. Lower = closer to dataset.
213 | adam_weight_decay: 0
214 |
215 | # Optimizer parameters for the UNET. Overrides base learning rate parameters.
216 | extra_unet_params: null
217 | #learning_rate: 1e-5
218 | #adam_weight_decay: 1e-4
219 |
220 | # Optimizer parameters for the Text Encoder. Overrides base learning rate parameters.
221 | extra_text_encoder_params: null
222 | #learning_rate: 1e-4
223 | #adam_weight_decay: 0.2
224 |
225 | # The number of samples per training batch. Not to be confused with video frames.
226 | train_batch_size: 1
227 | gradient_accumulation_steps: 2
228 | # Maximum number of train steps. Model is saved after training.
229 | max_train_steps: 10000
230 |
231 | # Saves a model every nth step.
232 | checkpointing_steps: 1000
233 |
234 | # How many steps to do for validation if sample_preview is enabled.
235 | validation_steps: 100
236 |
237 | # Which modules we want to unfreeze for the UNET. Advanced usage.
238 | trainable_modules:
239 | - "all"
240 | # If you want to ignore temporal attention entirely, remove "attn1-2" and replace with ".attentions"
241 | # This is for self attention. Activates for spatial and temporal dimensions if n_sample_frames > 1
242 | #- "attn1"
243 | - ".attentions"
244 |
245 | # This is for cross attention (image & text data). Activates for spatial and temporal dimensions if n_sample_frames > 1
246 | #- "attn2"
247 |
248 | # sample input and output
249 | - "conv_in"
250 | - "conv_out"
251 |
252 | # Time condition
253 | - '_proj'
254 | - '_embedding'
255 |
256 |
257 | # Which modules we want to unfreeze for the Text Encoder. Advanced usage.
258 | trainable_text_modules:
259 | null
260 | #- "embedding"
261 |
262 | # Seed for validation.
263 | seed: 6
264 |
265 | # Whether or not we want to use mixed precision with accelerate
266 | mixed_precision: "fp16"
267 |
268 | # This seems to be incompatible at the moment.
269 | use_8bit_adam: False
270 |
271 | # Trades training speed for lower VRAM usage. You lose roughly 20% of training speed, but save a lot of VRAM.
272 | # If you need to save more VRAM, it can also be enabled for the text encoder, but this roughly halves speed.
273 | gradient_checkpointing: True
274 | text_encoder_gradient_checkpointing: False
275 |
276 | # Xformers must be installed for best memory savings and performance (< Pytorch 2.0)
277 | enable_xformers_memory_efficient_attention: False
278 |
279 | # Use scaled dot product attention (Only available with >= Torch 2.0)
280 | enable_torch_2_attn: True
281 |
--------------------------------------------------------------------------------
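train_svd_v2v.yaml above is the only config here that uses `extra_train_data`: each entry carries its own `dataset_types` and `train_data` overrides and contributes an additional dataset on top of the main one. A sketch of how such entries could be concatenated into a single training set (illustrative only; `make_dataset` is a hypothetical stand-in for the repo's dataset factory in utils/dataset.py):

```python
from torch.utils.data import ConcatDataset

def build_datasets(config: dict, make_dataset) -> ConcatDataset:
    # make_dataset(dataset_types, train_data) is a hypothetical factory callable.
    datasets = [make_dataset(config["dataset_types"], config["train_data"])]
    for extra in config.get("extra_train_data", []):
        datasets.append(make_dataset(extra["dataset_types"], extra["train_data"]))
    return ConcatDataset(datasets)
```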
/example/validation_file.json:
--------------------------------------------------------------------------------
1 | [
2 | ["example/fish1.jpg", "The red fish is swimming"],
3 | ["example/barbie2.jpg", "a girl is talking, move head"],
4 | ["example/hulu3.jpg", "The man is talking, move hands."],
5 | ["example/pig0.jpg", "Three cartoon pigs are talking."],
6 | ["example/qingming2.jpg", "ships are sailing on the river."]
7 | ]
8 |
9 |
--------------------------------------------------------------------------------
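validation_file.json above is a plain JSON list of `[image_path, prompt]` pairs. A minimal sketch of iterating over it (illustrative only; the evaluation code reads it in the same spirit):

```python
import json

with open("example/validation_file.json") as f:
    samples = json.load(f)

for image_path, prompt in samples:
    print(image_path, "->", prompt)
```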
/models/layerdiffuse_VAE.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | from typing import Optional, Tuple
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models.modeling_utils import ModelMixin
7 | from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
8 |
9 | # referenced from https://github.com/layerdiffusion/sd-forge-layerdiffuse/blob/main/lib_layerdiffusion/models.py
10 |
11 | def zero_module(module):
12 | for p in module.parameters():
13 | p.detach().zero_()
14 | return module
15 |
16 |
17 | class LatentTransparencyOffsetEncoder(torch.nn.Module):
18 | def __init__(self, *args, **kwargs):
19 | super().__init__(*args, **kwargs)
20 | self.blocks = torch.nn.Sequential(
21 | torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1),
22 | nn.SiLU(),
23 | torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
24 | nn.SiLU(),
25 | torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
26 | nn.SiLU(),
27 | torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
28 | nn.SiLU(),
29 | torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
30 | nn.SiLU(),
31 | torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
32 | nn.SiLU(),
33 | torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
34 | nn.SiLU(),
35 | torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
36 | nn.SiLU(),
37 | zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)),
38 | )
39 |
40 | def __call__(self, x):
41 | return self.blocks(x)
42 |
43 |
44 | class UNet384(ModelMixin, ConfigMixin):
45 | @register_to_config
46 | def __init__(
47 | self,
48 | in_channels: int = 3,
49 | out_channels: int = 4,
50 | down_block_types: Tuple[str] = ("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
51 | up_block_types: Tuple[str] = ("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
52 | block_out_channels: Tuple[int] = (32, 64, 128, 256),
53 | layers_per_block: int = 2,
54 | mid_block_scale_factor: float = 1,
55 | downsample_padding: int = 1,
56 | downsample_type: str = "conv",
57 | upsample_type: str = "conv",
58 | dropout: float = 0.0,
59 | act_fn: str = "silu",
60 | attention_head_dim: Optional[int] = 8,
61 | norm_num_groups: int = 4,
62 | norm_eps: float = 1e-5,
63 | ):
64 | super().__init__()
65 |
66 | # input
67 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
68 | self.latent_conv_in = zero_module(nn.Conv2d(4, block_out_channels[2], kernel_size=1))
69 |
70 | self.down_blocks = nn.ModuleList([])
71 | self.mid_block = None
72 | self.up_blocks = nn.ModuleList([])
73 |
74 | # down
75 | output_channel = block_out_channels[0]
76 | for i, down_block_type in enumerate(down_block_types):
77 | input_channel = output_channel
78 | output_channel = block_out_channels[i]
79 | is_final_block = i == len(block_out_channels) - 1
80 |
81 | down_block = get_down_block(
82 | down_block_type,
83 | num_layers=layers_per_block,
84 | in_channels=input_channel,
85 | out_channels=output_channel,
86 | temb_channels=None,
87 | add_downsample=not is_final_block,
88 | resnet_eps=norm_eps,
89 | resnet_act_fn=act_fn,
90 | resnet_groups=norm_num_groups,
91 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
92 | downsample_padding=downsample_padding,
93 | resnet_time_scale_shift="default",
94 | downsample_type=downsample_type,
95 | dropout=dropout,
96 | )
97 | self.down_blocks.append(down_block)
98 |
99 | # mid
100 | self.mid_block = UNetMidBlock2D(
101 | in_channels=block_out_channels[-1],
102 | temb_channels=None,
103 | dropout=dropout,
104 | resnet_eps=norm_eps,
105 | resnet_act_fn=act_fn,
106 | output_scale_factor=mid_block_scale_factor,
107 | resnet_time_scale_shift="default",
108 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
109 | resnet_groups=norm_num_groups,
110 | attn_groups=None,
111 | add_attention=True,
112 | )
113 |
114 | # up
115 | reversed_block_out_channels = list(reversed(block_out_channels))
116 | output_channel = reversed_block_out_channels[0]
117 | for i, up_block_type in enumerate(up_block_types):
118 | prev_output_channel = output_channel
119 | output_channel = reversed_block_out_channels[i]
120 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
121 |
122 | is_final_block = i == len(block_out_channels) - 1
123 |
124 | up_block = get_up_block(
125 | up_block_type,
126 | num_layers=layers_per_block + 1,
127 | in_channels=input_channel,
128 | out_channels=output_channel,
129 | prev_output_channel=prev_output_channel,
130 | temb_channels=None,
131 | add_upsample=not is_final_block,
132 | resnet_eps=norm_eps,
133 | resnet_act_fn=act_fn,
134 | resnet_groups=norm_num_groups,
135 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
136 | resnet_time_scale_shift="default",
137 | upsample_type=upsample_type,
138 | dropout=dropout,
139 | )
140 | self.up_blocks.append(up_block)
141 | prev_output_channel = output_channel
142 |
143 | # out
144 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
145 | self.conv_act = nn.SiLU()
146 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
147 |
148 | def forward(self, x, latent):
149 | sample_latent = self.latent_conv_in(latent)
150 | sample = self.conv_in(x)
151 | emb = None
152 |
153 | down_block_res_samples = (sample,)
154 | for i, downsample_block in enumerate(self.down_blocks):
155 | # 8X downsample
156 | if i == 3:
157 | sample = sample + sample_latent
158 |
159 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
160 | down_block_res_samples += res_samples
161 |
162 | assert len(self.down_blocks) == 4
163 |
164 | sample = self.mid_block(sample, emb)
165 |
166 | for upsample_block in self.up_blocks:
167 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
168 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
169 | sample = upsample_block(sample, res_samples, emb)
170 |
171 | sample = self.conv_norm_out(sample)
172 | sample = self.conv_act(sample)
173 | sample = self.conv_out(sample)
174 | return sample
175 |
176 | def __call__(self, x, latent):
177 | return self.forward(x, latent)
178 |
--------------------------------------------------------------------------------
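A hedged usage sketch for `UNet384` above: with the default configuration the RGB input is processed at full resolution and the 4-channel latent is injected at the 8x-downsampled stage, so a 384x384 image pairs with a 48x48 latent. The shapes below are inferred from the default block layout, not taken from the repository's pipelines.

import torch
# from models.layerdiffuse_VAE import UNet384  # import path assumed from this repo layout

model = UNet384()                       # defaults: in_channels=3, out_channels=4
rgb = torch.randn(1, 3, 384, 384)       # full-resolution RGB input
latent = torch.randn(1, 4, 48, 48)      # SD-style latent, 8x downsampled (384 / 8 = 48)
with torch.no_grad():
    out = model(rgb, latent)            # expected shape: (1, 4, 384, 384)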
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | torch==2.0.0
3 | torchvision
4 | diffusers==0.24.0
5 | transformers==4.36.2
6 | einops
7 | decord
8 | tqdm
9 | safetensors
10 | omegaconf
11 | opencv-python
12 | pydantic
13 | compel
14 | easydict
15 | rotary_embedding_torch
16 | imageio[ffmpeg]
17 | gradio
18 | httpx[socks]
19 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | python train.py --config ./configs/v2/infer_config_latent.yaml --eval
2 | accelerate launch train.py --config configs/v2/train_config_latent.yaml
3 |
4 |
--------------------------------------------------------------------------------
/stable_lora/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import os
5 | import loralib as loralb
6 | from loralib import LoRALayer
7 | import math
8 | import json
9 |
10 | from torch.utils.data import ConcatDataset
11 | from transformers import CLIPTokenizer
12 |
13 | try:
14 | from safetensors.torch import save_file, load_file
15 | except ImportError:
16 |     print("safetensors is not installed. Saving with use_safetensors will fail.")
17 |
18 | UNET_REPLACE = ["Transformer2DModel", "ResnetBlock2D"]
19 | TEXT_ENCODER_REPLACE = ["CLIPAttention", "CLIPTextEmbeddings"]
20 |
21 | UNET_ATTENTION_REPLACE = ["CrossAttention"]
22 | TEXT_ENCODER_ATTENTION_REPLACE = ["CLIPAttention", "CLIPTextEmbeddings"]
23 |
24 | """
25 | Copied from: https://github.com/cloneofsimo/lora/blob/bdd51b04c49fa90a88919a19850ec3b4cf3c5ecd/lora_diffusion/lora.py#L189
26 | """
27 | def find_modules(
28 | model,
29 | ancestor_class= None,
30 | search_class = [torch.nn.Linear],
31 | exclude_children_of = [loralb.Linear, loralb.Conv2d, loralb.Embedding],
32 | ):
33 | """
34 | Find all modules of a certain class (or union of classes) that are direct or
35 | indirect descendants of other modules of a certain class (or union of classes).
36 |
37 |     Returns all matching modules, along with the parent of those modules and the
38 | names they are referenced by.
39 | """
40 |
41 | # Get the targets we should replace all linears under
42 | if ancestor_class is not None:
43 | ancestors = (
44 | module
45 | for module in model.modules()
46 | if module.__class__.__name__ in ancestor_class
47 | )
48 | else:
49 |         # this is in case you want to naively iterate over all modules.
50 | ancestors = [module for module in model.modules()]
51 |
52 | # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
53 | for ancestor in ancestors:
54 | for fullname, module in ancestor.named_modules():
55 | if any([isinstance(module, _class) for _class in search_class]):
56 | # Find the direct parent if this is a descendant, not a child, of target
57 | *path, name = fullname.split(".")
58 | parent = ancestor
59 | while path:
60 | parent = parent.get_submodule(path.pop(0))
61 | # Skip this linear if it's a child of a LoraInjectedLinear
62 | if exclude_children_of and any(
63 | [isinstance(parent, _class) for _class in exclude_children_of]
64 | ):
65 | continue
66 | # Otherwise, yield it
67 | yield parent, name, module
68 |
69 | class Conv2d(nn.Conv2d, LoRALayer):
70 | # LoRA implemented in a dense layer
71 | def __init__(
72 | self,
73 | in_channels: int,
74 | out_channels: int,
75 | kernel_size: int,
76 | r: int = 0,
77 | lora_alpha: int = 1,
78 | lora_dropout: float = 0.,
79 | merge_weights: bool = True,
80 | **kwargs
81 | ):
82 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
83 | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
84 | merge_weights=merge_weights)
85 | assert type(kernel_size) is int
86 | # Actual trainable parameters
87 | if r > 0:
88 | self.lora_A = nn.Parameter(
89 | self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
90 | )
91 | self.lora_B = nn.Parameter(
92 | self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
93 | )
94 | self.scaling = self.lora_alpha / self.r
95 | # Freezing the pre-trained weight matrix
96 | self.weight.requires_grad = False
97 | self.reset_parameters()
98 |
99 | def reset_parameters(self):
100 | nn.Conv2d.reset_parameters(self)
101 | if hasattr(self, 'lora_A'):
102 | # initialize A the same way as the default for nn.Linear and B to zero
103 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
104 | nn.init.zeros_(self.lora_B)
105 |
106 | def train(self, mode: bool = True):
107 | nn.Conv2d.train(self, mode)
108 | if mode:
109 | if self.merge_weights and self.merged:
110 | # Make sure that the weights are not merged
111 | self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
112 | self.merged = False
113 | else:
114 | if self.merge_weights and not self.merged:
115 | # Merge the weights and mark it
116 | self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
117 | self.merged = True
118 |
119 | def forward(self, x: torch.Tensor):
120 | if self.r > 0 and not self.merged:
121 | return F.conv2d(
122 | x,
123 | self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
124 | self.bias, self.stride, self.padding, self.dilation, self.groups
125 | )
126 | return nn.Conv2d.forward(self, x)
127 |
128 | class Conv3d(nn.Conv3d, LoRALayer):
129 | # LoRA implemented in a dense layer
130 | def __init__(
131 | self,
132 | in_channels: int,
133 | out_channels: int,
134 | kernel_size: int,
135 | r: int = 0,
136 | lora_alpha: int = 1,
137 | lora_dropout: float = 0.,
138 | merge_weights: bool = True,
139 | **kwargs
140 | ):
141 | nn.Conv3d.__init__(self, in_channels, out_channels, (kernel_size, 1, 1), **kwargs)
142 | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
143 | merge_weights=merge_weights)
144 | assert type(kernel_size) is int
145 | # Actual trainable parameters
146 |
147 | # Get view transform shape
148 | i, o, k = self.weight.shape[:3]
149 | self.view_shape = (i, o, k, kernel_size, 1)
150 | self.force_disable_merge = True
151 |
152 | if r > 0:
153 | self.lora_A = nn.Parameter(
154 | self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
155 | )
156 | self.lora_B = nn.Parameter(
157 | self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
158 | )
159 | self.scaling = self.lora_alpha / self.r
160 | # Freezing the pre-trained weight matrix
161 | self.weight.requires_grad = False
162 | self.reset_parameters()
163 |
164 | def reset_parameters(self):
165 | nn.Conv3d.reset_parameters(self)
166 | if hasattr(self, 'lora_A'):
167 | # initialize A the same way as the default for nn.Linear and B to zero
168 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
169 | nn.init.zeros_(self.lora_B)
170 |
171 | def train(self, mode: bool = True):
172 | nn.Conv3d.train(self, mode)
173 |
174 | # HACK Merging the weights this way could potentially cause vanishing gradients if validation is enabled.
175 | # If you are to save this as a pretrained model, you will have to merge these weights afterwards, then save.
176 | if self.force_disable_merge:
177 | return
178 |
179 | if mode:
180 | if self.merge_weights and self.merged:
181 | # Make sure that the weights are not merged
182 | self.weight.data -= torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * self.scaling
183 | self.merged = False
184 | else:
185 | if self.merge_weights and not self.merged:
186 | # Merge the weights and mark it
187 | self.weight.data += torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * self.scaling
188 | self.merged = True
189 |
190 | def forward(self, x: torch.Tensor):
191 | if self.r > 0 and not self.merged:
192 | return F.conv3d(
193 | x,
194 | self.weight + torch.mean((self.lora_B @ self.lora_A).view(self.view_shape), dim=-2, keepdim=True) * \
195 | self.scaling, self.bias, self.stride, self.padding, self.dilation, self.groups
196 | )
197 | return nn.Conv3d.forward(self, x)
198 |
199 | def create_lora_linear(child_module, r, dropout=0, bias=False, scale=0):
200 | return loralb.Linear(
201 | child_module.in_features,
202 | child_module.out_features,
203 | merge_weights=False,
204 | bias=bias,
205 | lora_dropout=dropout,
206 | lora_alpha=r,
207 | r=r
208 | )
209 |
210 |
211 | def create_lora_conv(child_module, r, dropout=0, bias=False, rescale=False, scale=0):
212 | return Conv2d(
213 | child_module.in_channels,
214 | child_module.out_channels,
215 | kernel_size=child_module.kernel_size[0],
216 | padding=child_module.padding,
217 | stride=child_module.stride,
218 | merge_weights=False,
219 | bias=bias,
220 | lora_dropout=dropout,
221 | lora_alpha=r,
222 | r=r,
223 | )
224 |
225 |
226 | def create_lora_conv3d(child_module, r, dropout=0, bias=False, rescale=False, scale=0):
227 | return Conv3d(
228 | child_module.in_channels,
229 | child_module.out_channels,
230 | kernel_size=child_module.kernel_size[0],
231 | padding=child_module.padding,
232 | stride=child_module.stride,
233 | merge_weights=False,
234 | bias=bias,
235 | lora_dropout=dropout,
236 | lora_alpha=r,
237 | r=r,
238 | )
239 |
240 |
241 | def create_lora_emb(child_module, r):
242 | return loralb.Embedding(
243 | child_module.num_embeddings,
244 | child_module.embedding_dim,
245 | merge_weights=False,
246 | lora_alpha=r,
247 | r=r
248 | )
249 |
250 | def activate_lora_train(model, bias):
251 | def unfreeze():
252 | print(model.__class__.__name__ + " LoRA set for training.")
253 | return loralb.mark_only_lora_as_trainable(model, bias=bias)
254 |
255 | return unfreeze
256 |
257 | def add_lora_to(
258 | model,
259 | target_module=UNET_REPLACE,
260 | search_class=[torch.nn.Linear],
261 | r=32,
262 | dropout=0,
263 | lora_bias='none'
264 | ):
265 | for module, name, child_module in find_modules(
266 | model,
267 | ancestor_class=target_module,
268 | search_class=search_class
269 | ):
270 | bias = hasattr(child_module, "bias")
271 |
272 | # Check if child module of the model has bias.
273 | if bias:
274 | if child_module.bias is None:
275 | bias = False
276 |
277 | # Check if the child module of the model is type Linear or Conv2d.
278 | if isinstance(child_module, torch.nn.Linear):
279 | l = create_lora_linear(child_module, r, dropout, bias=bias)
280 |
281 | if isinstance(child_module, torch.nn.Conv2d):
282 | l = create_lora_conv(child_module, r, dropout, bias=bias)
283 |
284 | if isinstance(child_module, torch.nn.Conv3d):
285 | l = create_lora_conv3d(child_module, r, dropout, bias=bias)
286 |
287 | if isinstance(child_module, torch.nn.Embedding):
288 | l = create_lora_emb(child_module, r)
289 |
290 | # If the model has bias and we wish to add it, use the child_modules in place
291 | if bias:
292 | l.bias = child_module.bias
293 |
294 | # Assign the frozen weight of model's Linear or Conv2d to the LoRA model.
295 | l.weight = child_module.weight
296 |
297 |         # Replace the model's Linear or Conv2d module with the new LoRA module.
298 | module._modules[name] = l
299 |
300 |
301 | # Unfreeze only the newly added LoRA weights, but keep the model frozen.
302 | return activate_lora_train(model, lora_bias)
303 |
304 | def save_lora(
305 | unet=None,
306 | text_encoder=None,
307 | save_text_weights=False,
308 | output_dir="output",
309 | lora_filename="lora.safetensors",
310 | lora_bias='none',
311 | save_for_webui=True,
312 | only_webui=False,
313 | metadata=None,
314 | unet_dict_converter=None,
315 | text_dict_converter=None
316 | ):
317 |
318 | if not only_webui:
319 | # Create directory for the full LoRA weights.
320 | trainable_weights_dir = f"{output_dir}/full_weights"
321 | lora_out_file_full_weight = f"{trainable_weights_dir}/{lora_filename}"
322 | os.makedirs(trainable_weights_dir, exist_ok=True)
323 |
324 | ext = '.safetensors'
325 | # Create LoRA out filename.
326 | lora_out_file = f"{output_dir}/webui_{lora_filename}{ext}"
327 |
328 | if not only_webui:
329 | save_path_full_weights = lora_out_file_full_weight + ext
330 |
331 | save_path = lora_out_file
332 |
333 | if not only_webui:
334 | for i, model in enumerate([unet, text_encoder]):
335 | if save_text_weights and i == 1:
336 | non_webui_weights = save_path_full_weights.replace(ext, f"_text_encoder{ext}")
337 |
338 | else:
339 | non_webui_weights = save_path_full_weights.replace(ext, f"_unet{ext}")
340 |
341 | # Load only the LoRAs from the state dict.
342 | lora_dict = loralb.lora_state_dict(model, bias=lora_bias)
343 |
344 | # Save the models as fp32. This ensures we can finetune again without having to upcast.
345 | save_file(lora_dict, non_webui_weights)
346 |
347 | if save_for_webui:
348 | # Convert the keys to compvis model and webui
349 | unet_lora_dict = loralb.lora_state_dict(unet, bias=lora_bias)
350 | lora_dict_fp16 = unet_dict_converter(unet_lora_dict, strict_mapping=True)
351 |
352 | if save_text_weights:
353 | text_encoder_dict = loralb.lora_state_dict(text_encoder, bias=lora_bias)
354 | lora_dict_text_fp16 = text_dict_converter(text_encoder_dict)
355 |
356 | # Update the Unet dict to include text keys.
357 | lora_dict_fp16.update(lora_dict_text_fp16)
358 |
359 | # Cast tensors to fp16. It's assumed we won't be finetuning these.
360 | for k, v in lora_dict_fp16.items():
361 | lora_dict_fp16[k] = v.to(dtype=torch.float16)
362 |
363 | save_file(
364 | lora_dict_fp16,
365 | save_path,
366 | metadata=metadata
367 | )
368 |
369 | def load_lora(model, lora_path: str):
370 | try:
371 | if os.path.exists(lora_path):
372 | lora_dict = load_file(lora_path)
373 | model.load_state_dict(lora_dict, strict=False)
374 |
375 | except Exception as e:
376 | print(f"Could not load your lora file: {e}")
377 |
378 | def set_mode(model, train=False):
379 | for n, m in model.named_modules():
380 | is_lora = hasattr(m, 'merged')
381 | if is_lora:
382 | m.train(train)
383 |
384 | def set_mode_group(models, train):
385 | for model in models:
386 | set_mode(model, train)
387 | model.train(train)
388 |
--------------------------------------------------------------------------------
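A brief, hypothetical sketch of the call pattern for `add_lora_to` and `load_lora` above. The `unet` variable and the checkpoint path are assumptions; the real training scripts wire this up through their own LoRA handlers.

import torch
# `unet` is assumed to be an already-loaded diffusers UNet.

unfreeze = add_lora_to(
    unet,
    target_module=UNET_REPLACE,                     # ["Transformer2DModel", "ResnetBlock2D"]
    search_class=[torch.nn.Linear, torch.nn.Conv2d],
    r=32,
)
unfreeze()  # marks only the injected LoRA parameters as trainable

# ... training loop ...

# Reload previously saved LoRA weights (path is an assumption).
load_lora(unet, "output/full_weights/lora_unet.safetensors")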
/svd_video2video_examples/barbie_input.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_input.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/barbie_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_mask.png
--------------------------------------------------------------------------------
/svd_video2video_examples/barbie_output.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/barbie_output.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/car_input.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_input.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/car_mask_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_mask_1.png
--------------------------------------------------------------------------------
/svd_video2video_examples/car_mask_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_mask_2.png
--------------------------------------------------------------------------------
/svd_video2video_examples/car_output_1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_output_1.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/car_output_2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/car_output_2.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/windmill_input.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_input.mp4
--------------------------------------------------------------------------------
/svd_video2video_examples/windmill_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_mask.png
--------------------------------------------------------------------------------
/svd_video2video_examples/windmill_output.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/svd_video2video_examples/windmill_output.mp4
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/animate-anything/dca241202348fa2ac68e22e81fca042fe05e3550/utils/__init__.py
--------------------------------------------------------------------------------
/utils/bucketing.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | def min_res(size, min_size): return min_size if size < min_size else size
4 |
5 | def up_down_bucket(m_size, in_size, direction):
6 | if direction == 'down': return abs(int(m_size - in_size))
7 | if direction == 'up': return abs(int(m_size + in_size))
8 |
9 | def get_bucket_sizes(size, direction: str, min_size):
10 | multipliers = [64, 128]
11 | for i, m in enumerate(multipliers):
12 | res = up_down_bucket(m, size, direction)
13 | multipliers[i] = min_res(res, min_size=min_size)
14 | return multipliers
15 |
16 | def closest_bucket(m_size, size, direction, min_size):
17 | lst = get_bucket_sizes(m_size, direction, min_size)
18 | return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))]
19 |
20 | def resolve_bucket(i,h,w): return (i / (h / w))
21 |
22 | def sensible_buckets(m_width, m_height, w, h, min_size=192):
23 | if h > w:
24 | w = resolve_bucket(m_width, h, w)
25 | w = closest_bucket(m_width, w, 'down', min_size=min_size)
26 | return w, m_height
27 | if h < w:
28 | h = resolve_bucket(m_height, w, h)
29 | h = closest_bucket(m_height, h, 'down', min_size=min_size)
30 | return m_width, h
31 |
32 | return m_width, m_height
--------------------------------------------------------------------------------
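A quick worked example for `sensible_buckets` above, with values computed from the functions as written: a landscape 1920x1080 clip constrained to 512x512 keeps the maximum width and snaps the height to the nearest bucket.

# 1920x1080 scaled to preserve aspect ratio at width 512 gives a height of 288,
# which is closest to the 384 bucket (|448-288| = 160 vs |384-288| = 96).
w, h = sensible_buckets(512, 512, 1920, 1080)
print(w, h)  # 512 384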
/utils/common.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | from PIL import Image
4 | import torch
5 | import random
6 | import numpy as np
7 | import torchvision.transforms as T
8 | from einops import rearrange, repeat
9 | import imageio
10 | import sys
11 |
12 | def tensor_to_vae_latent(t, vae):
13 | video_length = t.shape[1]
14 |
15 | t = rearrange(t, "b f c h w -> (b f) c h w")
16 | latents = vae.encode(t).latent_dist.mode()
17 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
18 | latents = latents * 0.18215
19 |
20 | return latents
21 |
22 | def DDPM_forward(x0, step, num_frames, scheduler):
23 | device = x0.device
24 | t = scheduler.timesteps[-1]
25 | xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames)
26 |
27 | eps = torch.randn_like(xt)
28 | alpha_vec = torch.prod(scheduler.alphas[t:])
29 | xt = torch.sqrt(alpha_vec) * xt + torch.sqrt(1-alpha_vec) * eps
30 | return xt, None
31 |
32 | def DDPM_forward_timesteps(x0, step, num_frames, scheduler):
33 | '''larger step -> smaller t -> smaller alphas[t:] -> smaller xt -> smaller x0'''
34 |
35 | device = x0.device
36 | # timesteps are reversed
37 | timesteps = scheduler.timesteps[len(scheduler.timesteps)-step:]
38 | t = timesteps[0]
39 |
40 | if x0.shape[2] == 1:
41 | xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames)
42 | else:
43 | xt = x0
44 | noise = torch.randn(xt.shape, dtype=xt.dtype, device=device)
45 | # t to tensor of batch size
46 | t = torch.tensor([t]*xt.shape[0], device=device)
47 | xt = scheduler.add_noise(xt, noise, t)
48 | return xt, timesteps
49 |
50 | def DDPM_forward_mask(x0, step, num_frames, scheduler, mask):
51 | '''larger step -> smaller t -> smaller alphas[t:] -> smaller xt -> smaller x0'''
52 | device = x0.device
53 | dtype = x0.dtype
54 | b, c, f, h, w = x0.shape
55 |
56 | move_xt, timesteps = DDPM_forward_timesteps(x0, step, num_frames, scheduler)
57 | mask = T.ToTensor()(mask).to(dtype).to(device)
58 | mask = T.Resize([h, w], antialias=False)(mask)
59 | mask = rearrange(mask, 'b h w -> b 1 1 h w')
60 | freeze_xt = repeat(x0, 'b c 1 h w -> b c f h w', f = num_frames)
61 | initial = freeze_xt * (1-mask) + move_xt * mask
62 | return initial, timesteps
63 |
64 | def read_video(video_path, frame_number=-1):
65 | # Open the video file
66 | cap = cv2.VideoCapture(video_path)
67 | count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
68 | if frame_number == -1:
69 | frame_number = count
70 | else:
71 | frame_number = min(frame_number, count)
72 | frames = []
73 | for i in range(frame_number):
74 |         ret, ref_frame = cap.read()
75 |         if not ret:
76 |             raise ValueError("Failed to read video file")
77 |         ref_frame = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2RGB)
78 |         frames.append(ref_frame)
79 | return frames
80 |
81 | def get_full_white_area_mask(frames):
82 | ref_frame = frames[0]
83 | ref_gray = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2GRAY)
84 | total_mask = np.ones_like(ref_gray) * 255
85 |
86 | return total_mask
87 |
88 | def get_moved_area_mask(frames, move_th=5, th=-1):
89 | ref_frame = frames[0]
90 | # Convert the reference frame to gray
91 | ref_gray = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2GRAY)
92 | prev_gray = ref_gray
93 | # Initialize the total accumulated motion mask
94 | total_mask = np.zeros_like(ref_gray)
95 |
96 | # Iterate through the video frames
97 | for i in range(1, len(frames)):
98 | frame = frames[i]
99 | # Convert the frame to gray
100 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
101 |
102 | # Compute the absolute difference between the reference frame and the current frame
103 | diff = cv2.absdiff(ref_gray, gray)
104 | #diff += cv2.absdiff(prev_gray, gray)
105 |
106 | # Apply a threshold to obtain a binary image
107 | ret, mask = cv2.threshold(diff, move_th, 255, cv2.THRESH_BINARY)
108 |
109 | # Accumulate the mask
110 | total_mask = cv2.bitwise_or(total_mask, mask)
111 |
112 | # Update the reference frame
113 | prev_gray = gray
114 |
115 | contours, _ = cv2.findContours(total_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
116 | rects = []
117 | ref_mask = np.zeros_like(ref_gray)
118 | ref_mask = cv2.drawContours(ref_mask, contours, -1, (255, 255, 255), -1)
119 | for cnt in contours:
120 | cur_rec = cv2.boundingRect(cnt)
121 | rects.append(cur_rec)
122 |
123 | #rects = merge_overlapping_rectangles(rects)
124 | mask = np.zeros_like(ref_gray)
125 | if th < 0:
126 | h, w = mask.shape
127 | th = int(h*w*0.005)
128 | for rect in rects:
129 | x, y, w, h = rect
130 | if w*h < th:
131 | continue
132 | #ref_frame = cv2.rectangle(ref_frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
133 | mask[y:y+h, x:x+w] = 255
134 | return mask
135 |
136 | def calculate_motion_precision(frames, mask):
137 | moved_mask = get_moved_area_mask(frames, move_th=20, th=0)
138 | moved = moved_mask == 255
139 | gt = mask == 255
140 | precision = np.sum(moved & gt) / np.sum(moved)
141 | return precision
142 |
143 | def check_overlap(rect1, rect2):
144 | # Calculate the coordinates of the edges of the rectangles
145 | rect1_left = rect1[0]
146 | rect1_right = rect1[0] + rect1[2]
147 | rect1_top = rect1[1]
148 | rect1_bottom = rect1[1] + rect1[3]
149 |
150 | rect2_left = rect2[0]
151 | rect2_right = rect2[0] + rect2[2]
152 | rect2_top = rect2[1]
153 | rect2_bottom = rect2[1] + rect2[3]
154 |
155 | # Check if the rectangles overlap
156 | if (rect2_left >= rect1_right or rect2_right <= rect1_left or
157 | rect2_top >= rect1_bottom or rect2_bottom <= rect1_top):
158 | return False
159 | else:
160 | return True
161 |
162 | def merge_rects(rect1, rect2):
163 | left = min(rect1[0], rect2[0])
164 | top = min(rect1[1], rect2[1])
165 | right = max(rect1[0]+rect1[2], rect2[0]+rect2[2])
166 | bottom = max(rect1[1]+rect1[3], rect2[1]+rect2[3])
167 | width = right - left
168 | height = bottom - top
169 | return (left, top, width, height)
170 |
171 | def merge_overlapping_rectangles(rectangles):
172 | # Sort the rectangles based on their left coordinate
173 | sorted_rectangles = sorted(rectangles, key=lambda x: x[0])
174 |
175 | # Initialize an empty list to store the merged rectangles
176 | merged_rectangles = []
177 |
178 | # Iterate through the sorted rectangles and merge them
179 | for rect in sorted_rectangles:
180 | if not merged_rectangles:
181 | # If the merged rectangles list is empty, add the first rectangle to it
182 | merged_rectangles.append(rect)
183 | else:
184 | # Get the last merged rectangle
185 | last_merged = merged_rectangles[-1]
186 |
187 | # Check if the current rectangle overlaps with the last merged rectangle
188 | if last_merged[0] + last_merged[2] >= rect[0]:
189 | # Merge the rectangles if they overlap
190 | merged_rectangles[-1] = (
191 | min(last_merged[0], rect[0]),
192 | min(last_merged[1], rect[1]),
193 | max(last_merged[0] + last_merged[2], rect[0] + rect[2]) - min(last_merged[0], rect[0]),
194 | max(last_merged[1] + last_merged[3], rect[1] + rect[3]) - min(last_merged[1], rect[1])
195 | )
196 | else:
197 | # Add the current rectangle to the merged rectangles list if they don't overlap
198 | merged_rectangles.append(rect)
199 |
200 | return merged_rectangles
201 |
202 | def generate_random_mask(image):
203 | # Create a blank mask with the same size as the image
204 | b, c , h, w = image.shape
205 | mask = np.zeros([b, h, w], dtype=np.uint8)
206 |
207 | # Generate random coordinates for the mask
208 | num_points = np.random.randint(3, 10) # Randomly choose the number of points to generate
209 | points = np.random.randint(0, min(h, w), size=(num_points, 2)) # Randomly generate the points
210 | # Draw a filled polygon on the mask using the random points
211 | for i in range(b):
212 | width = random.randint(w//4, w)
213 | height = random.randint(h//4, h)
214 | x = random.randint(0, w-width)
215 | y = random.randint(0, h-height)
216 | points=np.array([[x, y], [x+width, y], [x+width, y+height], [x, y+height]])
217 | mask[i] = cv2.fillPoly(mask[i], [points], 255)
218 |
219 | # Apply the mask to the image
220 | #masked_image = cv2.bitwise_and(image, image, mask=mask)
221 | return mask
222 |
223 | def generate_center_mask(image):
224 | # Create a blank mask with the same size as the image
225 | b, c , h, w = image.shape
226 | mask = np.zeros([b, h, w], dtype=np.uint8)
227 |
228 | # Generate random coordinates for the mask
229 | for i in range(b):
230 | width = int(w/10)
231 | height = int(h/10)
232 | mask[i][height:-height,width:-width] = 255
233 | # Apply the mask to the image
234 | #masked_image = cv2.bitwise_and(image, image, mask=mask)
235 | return mask
236 |
237 | def read_mask(json_path, label=["mask"]):
238 | j = json.load(open(json_path))
239 |     if not isinstance(label, list):
240 |         label = [label]
241 | height = j['imageHeight']
242 | width = j['imageWidth']
243 | mask = np.zeros([height, width], dtype=np.uint8)
244 | for shape in j['shapes']:
245 | if shape['label'] in label:
246 | x1, y1 = shape['points'][0]
247 | x2, y2 = shape['points'][1]
248 | mask[int(y1):int(y2), int(x1):int(x2)] = 255
249 | return mask
250 |
251 |
252 | def slerp(z1, z2, alpha):
253 | theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
254 | return (
255 | torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
256 | + torch.sin(alpha * theta) / torch.sin(theta) * z2
257 | )
258 |
259 | def _detect_edges(lum: np.ndarray, kernel_size=5) -> np.ndarray:
260 | """Detect edges using the luma channel of a frame.
261 |
262 | Arguments:
263 | lum: 2D 8-bit image representing the luma channel of a frame.
264 |
265 | Returns:
266 | 2D 8-bit image of the same size as the input, where pixels with values of 255
267 | represent edges, and all other pixels are 0.
268 | """
269 | # Initialize kernel.
270 | #kernel_size = _estimated_kernel_size(lum.shape[1], lum.shape[0])
271 | _kernel = np.ones((kernel_size, kernel_size), np.uint8)
272 |
273 | # Estimate levels for thresholding.
274 | # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc.
275 | sigma: float = 1.0 / 3.0
276 | median = np.median(lum)
277 | low = int(max(0, (1.0 - sigma) * median))
278 | high = int(min(255, (1.0 + sigma) * median))
279 |
280 | # Calculate edges using Canny algorithm, and reduce noise by dilating the edges.
281 | # This increases edge overlap leading to improved robustness against noise and slow
282 | # camera movement. Note that very large kernel sizes can negatively affect accuracy.
283 | edges = cv2.Canny(lum, low, high)
284 | return cv2.dilate(edges, _kernel)
285 |
286 |
287 | def _mean_pixel_distance(left: np.ndarray, right: np.ndarray) -> float:
288 | """Return the mean average distance in pixel values between `left` and `right`.
289 |     Both `left` and `right` should be 2-dimensional 8-bit images of the same shape.
290 | """
291 | assert len(left.shape) == 2 and len(right.shape) == 2
292 | assert left.shape == right.shape
293 | num_pixels: float = float(left.shape[0] * left.shape[1])
294 | return (np.sum(np.abs(left.astype(np.int32) - right.astype(np.int32))) / num_pixels)
295 |
296 | def calculate_latent_motion_score(latents):
297 | #latents b, c f, h, w
298 | diff=torch.abs(latents[:,:,1:]-latents[:,:,:-1])
299 | motion_score = torch.sum(torch.mean(diff, dim=[2,3,4]), dim=1) * 10
300 | return motion_score
301 |
302 | def motion_mask_loss(latents, mask):
303 | diff = torch.abs(latents[:,:,1:] - latents[:,:,:-1])
304 | loss = torch.sum(torch.mean(diff * (1-mask), dim=[2,3,4]), dim=1)
305 | return loss
306 |
307 | def calculate_motion_score(frame_imgs, calculate_edges=False, color="RGB") -> float:
308 | # Convert image into HSV colorspace.
309 | _last_frame = None
310 |
311 | _weights = [1.0, 1.0, 1.0, 0.0]
312 | score = 0
313 | for frame_img in frame_imgs:
314 | if color == "RGB":
315 | hue, sat, lum = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_RGB2HSV))
316 | else:
317 | hue, sat, lum = cv2.split(cv2.cvtColor(frame_img, cv2.COLOR_BGR2HSV))
318 | # Performance: Only calculate edges if we have to.
319 | edges = _detect_edges(lum) if calculate_edges else None
320 |         if _last_frame is None:
321 | _last_frame = (hue, sat, lum, edges)
322 | continue
323 |
324 | score_components = [
325 | _mean_pixel_distance(hue, _last_frame[0]),
326 | _mean_pixel_distance(sat, _last_frame[1]),
327 | _mean_pixel_distance(lum, _last_frame[2]),
328 | 0.0 if edges is None else _mean_pixel_distance(edges, _last_frame[3]),
329 | ]
330 |
331 | frame_score: float = (
332 | sum(component * weight for (component, weight) in zip(score_components, _weights))
333 | / sum(abs(weight) for weight in _weights))
334 | score += frame_score
335 | _last_frame = (hue, sat, lum, edges)
336 |
337 | return round(score/(len(frame_imgs)-1) * 10)
338 |
339 | if __name__ == "__main__":
340 |
341 | # Example usage
342 | video_paths = [
343 | "/data/video/animate2/Bleach.Sennen.Kessen.Hen.S01E01.2022.1080p.WEB-DL.x264.AAC-DDHDTV-Scene-002.mp4",
344 | "/data/video/animate2/Evangelion.3.0.1.01.Thrice.Upon.A.Time.2021.BLURAY.720p.BluRay.x264.AAC-[YTS.MX]-Scene-0780.mp4",
345 | "/data/video/animate2/[GM-Team][国漫][永生 第2季][IMMORTALITY Ⅱ][2023][09][AVC][GB][1080P]-Scene-180.mp4",
346 | "/data/video/animate2/[orion origin] Legend of the Galactic Heroes Die Neue These [07] [WebRip 1080p] [H265 AAC] [GB]-Scene-048.mp4",
347 | "/data/video/MSRVTT/videos/all/video33.mp4",
348 | "/webvid/webvid/data/videos/000001_000050/1066692580.mp4",
349 | "/webvid/webvid/data/videos/000001_000050/1066685533.mp4",
350 | "/webvid/webvid/data/videos/000001_000050/1066685548.mp4",
351 | "/webvid/webvid/data/videos/000001_000050/1066676380.mp4",
352 | "/webvid/webvid/data/videos/000001_000050/1066676377.mp4",
353 | ]
354 | for i, video_path in enumerate(video_paths[:5]):
355 | frames = read_video(video_path, 200)[::3]
356 | if sys.argv[1] == 'test_mask':
357 | mask = get_moved_area_mask(frames)
358 | Image.fromarray(mask).save(f"output/mask/{i}.jpg")
359 | imageio.mimwrite(f"output/mask/{i}.gif", frames, duration=125, loop=0)
360 | elif sys.argv[1] == 'test_motion':
361 | for r in range(0, len(frames), 16):
362 | video_frames = frames[r:r+16]
363 | video_frames = [cv2.resize(f, (512, 512)) for f in video_frames]
364 | score = calculate_motion_score(video_frames, calculate_edges=False, color="BGR")
365 | imageio.mimwrite(f"output/example_video/{i}_{r}_{score}.mp4", video_frames, fps=8)
366 | elif sys.argv[1] == 'to_gif':
367 | imageio.mimwrite(f"output/example_video/{i}.gif", frames, duration=125, loop=0)
368 |
--------------------------------------------------------------------------------
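A hedged sketch tying together `tensor_to_vae_latent` and `calculate_latent_motion_score` above. The `vae` here is assumed to be an already-loaded diffusers `AutoencoderKL`, and the tensor shapes just illustrate the expected (batch, frames, channels, height, width) layout.

import torch
# `vae` is assumed to be a loaded diffusers AutoencoderKL on the right device/dtype.

video = torch.randn(1, 8, 3, 256, 256)           # (b, f, c, h, w), values roughly in [-1, 1]
latents = tensor_to_vae_latent(video.to(vae.device, vae.dtype), vae)
# latents: (b, c, f, h/8, w/8), already scaled by 0.18215

score = calculate_latent_motion_score(latents)   # one motion score per batch element
print(score.shape)                               # torch.Size([1])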
/utils/convert_diffusers_to_original_ms_text_to_video.py:
--------------------------------------------------------------------------------
1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2 | # *Only* converts the UNet, and Text Encoder.
3 | # Does not convert optimizer state or any other thing.
4 |
5 | import argparse
6 | import os.path as osp
7 | import re
8 |
9 | import torch
10 | from safetensors.torch import load_file, save_file
11 |
12 | # =================#
13 | # UNet Conversion #
14 | # =================#
15 |
16 | print ('Initializing the conversion map')
17 |
18 | unet_conversion_map = [
19 | # (ModelScope, HF Diffusers)
20 |
21 | # from Vanilla ModelScope/StableDiffusion
22 | ("time_embed.0.weight", "time_embedding.linear_1.weight"),
23 | ("time_embed.0.bias", "time_embedding.linear_1.bias"),
24 | ("time_embed.2.weight", "time_embedding.linear_2.weight"),
25 | ("time_embed.2.bias", "time_embedding.linear_2.bias"),
26 |
27 |
28 | # from Vanilla ModelScope/StableDiffusion
29 | ("input_blocks.0.0.weight", "conv_in.weight"),
30 | ("input_blocks.0.0.bias", "conv_in.bias"),
31 |
32 |
33 | # from Vanilla ModelScope/StableDiffusion
34 | ("out.0.weight", "conv_norm_out.weight"),
35 | ("out.0.bias", "conv_norm_out.bias"),
36 | ("out.2.weight", "conv_out.weight"),
37 | ("out.2.bias", "conv_out.bias"),
38 | ]
39 |
40 | unet_conversion_map_resnet = [
41 | # (ModelScope, HF Diffusers)
42 |
43 | # SD
44 | ("in_layers.0", "norm1"),
45 | ("in_layers.2", "conv1"),
46 | ("out_layers.0", "norm2"),
47 | ("out_layers.3", "conv2"),
48 | ("emb_layers.1", "time_emb_proj"),
49 | ("skip_connection", "conv_shortcut"),
50 |
51 | # MS
52 | #("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
53 | ]
54 |
55 | unet_conversion_map_layer = []
56 |
57 | # Convert input TemporalTransformer
58 | unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in'))
59 |
60 | # Reference for the default settings
61 |
62 | # "model_cfg": {
63 | # "unet_in_dim": 4,
64 | # "unet_dim": 320,
65 | # "unet_y_dim": 768,
66 | # "unet_context_dim": 1024,
67 | # "unet_out_dim": 4,
68 | # "unet_dim_mult": [1, 2, 4, 4],
69 | # "unet_num_heads": 8,
70 | # "unet_head_dim": 64,
71 | # "unet_res_blocks": 2,
72 | # "unet_attn_scales": [1, 0.5, 0.25],
73 | # "unet_dropout": 0.1,
74 | # "temporal_attention": "True",
75 | # "num_timesteps": 1000,
76 | # "mean_type": "eps",
77 | # "var_type": "fixed_small",
78 | # "loss_type": "mse"
79 | # }
80 |
81 | # hardcoded number of downblocks and resnets/attentions...
82 | # would need smarter logic for other networks.
83 | for i in range(4):
84 | # loop over downblocks/upblocks
85 |
86 | for j in range(2):
87 | # loop over resnets/attentions for downblocks
88 |
89 |         # Spatial SD stuff
90 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
91 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
92 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
93 |
94 | if i < 3:
95 | # no attention layers in down_blocks.3
96 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
97 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
98 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
99 |
100 | # Temporal MS stuff
101 | hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
102 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv."
103 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
104 |
105 | if i < 3:
106 | # no attention layers in down_blocks.3
107 | hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
108 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2."
109 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
110 |
111 | for j in range(3):
112 | # loop over resnets/attentions for upblocks
113 |
114 |         # Spatial SD stuff
115 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
116 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
117 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
118 |
119 | if i > 0:
120 | # no attention layers in up_blocks.0
121 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
122 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
123 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
124 |
125 | # loop over resnets/attentions for upblocks
126 | hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
127 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv."
128 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
129 |
130 | if i > 0:
131 | # no attention layers in up_blocks.0
132 | hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
133 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.2."
134 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
135 |
136 | # Up/Downsamplers are 2D, so don't need to touch them
137 | if i < 3:
138 | # no downsample in down_blocks.3
139 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
140 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op."
141 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
142 |
143 | # no upsample in up_blocks.3
144 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
145 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}."
146 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
147 |
148 |
149 | # Handle the middle block
150 |
151 | # Spatial
152 | hf_mid_atn_prefix = "mid_block.attentions.0."
153 | sd_mid_atn_prefix = "middle_block.1."
154 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
155 |
156 | for j in range(2):
157 | hf_mid_res_prefix = f"mid_block.resnets.{j}."
158 | sd_mid_res_prefix = f"middle_block.{3*j}."
159 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
160 |
161 | # Temporal
162 | hf_mid_atn_prefix = "mid_block.temp_attentions.0."
163 | sd_mid_atn_prefix = "middle_block.2."
164 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
165 |
166 | for j in range(2):
167 | hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
168 | sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv."
169 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
170 |
171 | # The pipeline
172 | def convert_unet_state_dict(unet_state_dict, strict_mapping=False):
173 | print ('Converting the UNET')
174 | # buyer beware: this is a *brittle* function,
175 | # and correct output requires that all of these pieces interact in
176 | # the exact order in which I have arranged them.
177 | mapping = {k: k for k in unet_state_dict.keys()}
178 |
179 | for sd_name, hf_name in unet_conversion_map:
180 | if strict_mapping:
181 | if hf_name in mapping:
182 | mapping[hf_name] = sd_name
183 | else:
184 | mapping[hf_name] = sd_name
185 | for k, v in mapping.items():
186 | if "resnets" in k:
187 | for sd_part, hf_part in unet_conversion_map_resnet:
188 | v = v.replace(hf_part, sd_part)
189 | mapping[k] = v
190 | # elif "temp_convs" in k:
191 | # for sd_part, hf_part in unet_conversion_map_resnet:
192 | # v = v.replace(hf_part, sd_part)
193 | # mapping[k] = v
194 | for k, v in mapping.items():
195 | for sd_part, hf_part in unet_conversion_map_layer:
196 | v = v.replace(hf_part, sd_part)
197 | mapping[k] = v
198 |
199 |
200 | # there must be a pattern, but I don't want to bother atm
201 | do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]]
202 | print (do_not_unsqueeze)
203 |
204 | new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()}
205 | # HACK: idk why the hell it does not work with list comprehension
206 | for k, v in new_state_dict.items():
207 | has_k = False
208 | for n in do_not_unsqueeze:
209 | if k == n:
210 | has_k = True
211 |
212 | if has_k:
213 | v = v.squeeze(-1)
214 | new_state_dict[k] = v
215 |
216 | return new_state_dict
217 |
218 | # TODO: VAE conversion. We don't train it in most cases, but it may be handy in the future --kabachuha
219 |
220 | # =========================#
221 | # Text Encoder Conversion #
222 | # =========================#
223 |
224 | # IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha
225 |
226 | # =========================#
227 | # Text Encoder Conversion #
228 | # =========================#
229 |
230 |
231 | textenc_conversion_lst = [
232 | # (stable-diffusion, HF Diffusers)
233 | ("resblocks.", "text_model.encoder.layers."),
234 | ("ln_1", "layer_norm1"),
235 | ("ln_2", "layer_norm2"),
236 | (".c_fc.", ".fc1."),
237 | (".c_proj.", ".fc2."),
238 | (".attn", ".self_attn"),
239 | ("ln_final.", "transformer.text_model.final_layer_norm."),
240 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
241 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
242 | ]
243 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
244 | textenc_pattern = re.compile("|".join(protected.keys()))
245 |
246 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
247 | code2idx = {"q": 0, "k": 1, "v": 2}
248 |
249 |
250 | def convert_text_enc_state_dict_v20(text_enc_dict):
251 | #print ('Converting the text encoder')
252 | new_state_dict = {}
253 | capture_qkv_weight = {}
254 | capture_qkv_bias = {}
255 | for k, v in text_enc_dict.items():
256 | if (
257 | k.endswith(".self_attn.q_proj.weight")
258 | or k.endswith(".self_attn.k_proj.weight")
259 | or k.endswith(".self_attn.v_proj.weight")
260 | ):
261 | k_pre = k[: -len(".q_proj.weight")]
262 | k_code = k[-len("q_proj.weight")]
263 | if k_pre not in capture_qkv_weight:
264 | capture_qkv_weight[k_pre] = [None, None, None]
265 | capture_qkv_weight[k_pre][code2idx[k_code]] = v
266 | continue
267 |
268 | if (
269 | k.endswith(".self_attn.q_proj.bias")
270 | or k.endswith(".self_attn.k_proj.bias")
271 | or k.endswith(".self_attn.v_proj.bias")
272 | ):
273 | k_pre = k[: -len(".q_proj.bias")]
274 | k_code = k[-len("q_proj.bias")]
275 | if k_pre not in capture_qkv_bias:
276 | capture_qkv_bias[k_pre] = [None, None, None]
277 | capture_qkv_bias[k_pre][code2idx[k_code]] = v
278 | continue
279 |
280 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
281 | new_state_dict[relabelled_key] = v
282 |
283 | for k_pre, tensors in capture_qkv_weight.items():
284 | if None in tensors:
285 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
286 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
287 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
288 |
289 | for k_pre, tensors in capture_qkv_bias.items():
290 | if None in tensors:
291 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
292 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
293 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
294 |
295 | return new_state_dict
296 |
297 |
298 | def convert_text_enc_state_dict(text_enc_dict):
299 | return text_enc_dict
300 |
301 | textenc_conversion_lst = [
302 | # (stable-diffusion, HF Diffusers)
303 | ("resblocks.", "text_model.encoder.layers."),
304 | ("ln_1", "layer_norm1"),
305 | ("ln_2", "layer_norm2"),
306 | (".c_fc.", ".fc1."),
307 | (".c_proj.", ".fc2."),
308 | (".attn", ".self_attn"),
309 | ("ln_final.", "transformer.text_model.final_layer_norm."),
310 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
311 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
312 | ]
313 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
314 | textenc_pattern = re.compile("|".join(protected.keys()))
315 |
316 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
317 | code2idx = {"q": 0, "k": 1, "v": 2}
318 |
319 |
320 | def convert_text_enc_state_dict_v20(text_enc_dict):
321 | new_state_dict = {}
322 | capture_qkv_weight = {}
323 | capture_qkv_bias = {}
324 | for k, v in text_enc_dict.items():
325 | if (
326 | k.endswith(".self_attn.q_proj.weight")
327 | or k.endswith(".self_attn.k_proj.weight")
328 | or k.endswith(".self_attn.v_proj.weight")
329 | ):
330 | k_pre = k[: -len(".q_proj.weight")]
331 | k_code = k[-len("q_proj.weight")]
332 | if k_pre not in capture_qkv_weight:
333 | capture_qkv_weight[k_pre] = [None, None, None]
334 | capture_qkv_weight[k_pre][code2idx[k_code]] = v
335 | continue
336 |
337 | if (
338 | k.endswith(".self_attn.q_proj.bias")
339 | or k.endswith(".self_attn.k_proj.bias")
340 | or k.endswith(".self_attn.v_proj.bias")
341 | ):
342 | k_pre = k[: -len(".q_proj.bias")]
343 | k_code = k[-len("q_proj.bias")]
344 | if k_pre not in capture_qkv_bias:
345 | capture_qkv_bias[k_pre] = [None, None, None]
346 | capture_qkv_bias[k_pre][code2idx[k_code]] = v
347 | continue
348 |
349 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
350 | new_state_dict[relabelled_key] = v
351 |
352 | for k_pre, tensors in capture_qkv_weight.items():
353 | if None in tensors:
354 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
355 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
356 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
357 |
358 | for k_pre, tensors in capture_qkv_bias.items():
359 | if None in tensors:
360 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
361 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
362 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
363 |
364 | return new_state_dict
365 |
366 |
367 | def convert_text_enc_state_dict(text_enc_dict):
368 | return text_enc_dict
369 |
370 | if __name__ == "__main__":
371 | parser = argparse.ArgumentParser()
372 |
373 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
374 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
375 | parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
376 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
377 | parser.add_argument(
378 | "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
379 | )
380 |
381 | args = parser.parse_args()
382 |
383 | assert args.model_path is not None, "Must provide a model path!"
384 |
385 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
386 |
387 | assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
388 |
389 | # Path for safetensors
390 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
391 | #vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
392 | text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
393 |
394 |     # Load models from safetensors if they exist, otherwise fall back to the PyTorch .bin files
395 | if osp.exists(unet_path):
396 | unet_state_dict = load_file(unet_path, device="cpu")
397 | else:
398 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
399 | unet_state_dict = torch.load(unet_path, map_location="cpu")
400 |
401 | # if osp.exists(vae_path):
402 | # vae_state_dict = load_file(vae_path, device="cpu")
403 | # else:
404 | # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
405 | # vae_state_dict = torch.load(vae_path, map_location="cpu")
406 |
407 | if osp.exists(text_enc_path):
408 | text_enc_dict = load_file(text_enc_path, device="cpu")
409 | else:
410 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
411 | text_enc_dict = torch.load(text_enc_path, map_location="cpu")
412 |
413 | # Convert the UNet model
414 | unet_state_dict = convert_unet_state_dict(unet_state_dict)
415 | #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
416 |
417 | # Convert the VAE model
418 | # vae_state_dict = convert_vae_state_dict(vae_state_dict)
419 | # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
420 |
421 | # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
422 | is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
423 |
424 | if is_v20_model:
425 |
426 | # MODELSCOPE always uses the 2.X encoder, btw --kabachuha
427 |
428 | # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
429 | text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
430 | text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
431 | #text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
432 | else:
433 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
434 | #text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
435 |
436 |     # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLIT FORM --kabachuha
437 | # Save CLIP and the Diffusion model to their own files
438 |
439 | #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
440 | print ('Saving UNET')
441 | state_dict = {**unet_state_dict}
442 |
443 | if args.half:
444 | state_dict = {k: v.half() for k, v in state_dict.items()}
445 |
446 | if args.use_safetensors:
447 | save_file(state_dict, args.checkpoint_path)
448 | else:
449 | #state_dict = {"state_dict": state_dict}
450 | torch.save(state_dict, args.checkpoint_path)
451 |
452 | # TODO: CLIP conversion doesn't work atm
453 | # print ('Saving CLIP')
454 | # state_dict = {**text_enc_dict}
455 |
456 | # if args.half:
457 | # state_dict = {k: v.half() for k, v in state_dict.items()}
458 |
459 | # if args.use_safetensors:
460 |     #     save_file(state_dict, args.clip_checkpoint_path)
461 | # else:
462 | # #state_dict = {"state_dict": state_dict}
463 | # torch.save(state_dict, args.clip_checkpoint_path)
464 |
465 |     print('Operation successful')
466 |
--------------------------------------------------------------------------------
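A minimal usage sketch for utils/convert_diffusers_to_original_ms_text_to_video.py above, calling its helpers from Python instead of the CLI. The checkpoint directory and output filename are illustrative assumptions, not files shipped with the repo.

import os.path as osp

import torch
from safetensors.torch import save_file

from utils.convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict

model_path = "./text-to-video-ms-1.7b"  # assumed local diffusers-layout checkpoint directory
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")

# Rename diffusers keys to the original ModelScope layout; the fp16 cast mirrors the --half flag
converted = convert_unet_state_dict(unet_state_dict)
converted = {k: v.half() for k, v in converted.items()}
save_file(converted, "text2video_pytorch_model.safetensors")  # assumed output filename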
/utils/lama.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on the implementation from:
3 | https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/tree/main
4 |
5 | Modules were adapted by Hans Brouwer to only support the final configuration of the model uploaded here:
6 | https://huggingface.co/akhaliq/lama
7 |
8 | Apache License 2.0: https://github.com/advimman/lama/blob/main/LICENSE
9 |
10 | @article{suvorov2021resolution,
11 | title={Resolution-robust Large Mask Inpainting with Fourier Convolutions},
12 | author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor},
13 | journal={arXiv preprint arXiv:2109.07161},
14 | year={2021}
15 | }
16 | """
17 |
18 | import os
19 | import sys
20 | from urllib.request import urlretrieve
21 |
22 | import torch
23 | from einops import rearrange
24 | from PIL import Image
25 | from torch import nn
26 | from torch.nn import functional as F
27 | from torchvision.transforms.functional import to_tensor
28 | from tqdm import tqdm
29 |
30 | from train import export_to_video
31 |
32 |
33 | LAMA_URL = "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt"
34 | LAMA_PATH = "models/lama.ckpt"
35 |
36 |
37 | def download_progress(t):
38 | last_b = [0]
39 |
40 | def update_to(b=1, bsize=1, tsize=None):
41 | if tsize is not None:
42 | t.total = tsize
43 | t.update((b - last_b[0]) * bsize)
44 | last_b[0] = b
45 |
46 | return update_to
47 |
48 |
49 | def download(url, path):
50 | with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=path) as t:
51 | urlretrieve(url, filename=path, reporthook=download_progress(t), data=None)
52 |
53 |
54 | class FourierUnit(nn.Module):
55 | def __init__(self, in_channels, out_channels, groups=1):
56 | super(FourierUnit, self).__init__()
57 | self.groups = groups
58 | self.conv_layer = torch.nn.Conv2d(
59 | in_channels=in_channels * 2,
60 | out_channels=out_channels * 2,
61 | kernel_size=1,
62 | stride=1,
63 | padding=0,
64 | groups=self.groups,
65 | bias=False,
66 | )
67 | self.bn = torch.nn.BatchNorm2d(out_channels * 2)
68 | self.relu = torch.nn.ReLU(inplace=True)
69 |
70 | def forward(self, x):
71 | batch = x.shape[0]
72 |
73 | # (batch, c, h, w/2+1, 2)
74 | fft_dim = (-2, -1)
75 | ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")
76 | ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
77 | ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
78 | ffted = ffted.view((batch, -1) + ffted.size()[3:])
79 |
80 | ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
81 | ffted = self.relu(self.bn(ffted))
82 |
83 | # (batch,c, t, h, w/2+1, 2)
84 | ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
85 | ffted = torch.complex(ffted[..., 0], ffted[..., 1])
86 |
87 | ifft_shape_slice = x.shape[-2:]
88 | output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm="ortho")
89 |
90 | return output
91 |
92 |
93 | class SpectralTransform(nn.Module):
94 | def __init__(self, in_channels, out_channels, stride=1, groups=1):
95 | super(SpectralTransform, self).__init__()
96 | self.stride = stride
97 | if stride == 2:
98 | self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
99 | else:
100 | self.downsample = nn.Identity()
101 |
102 | self.conv1 = nn.Sequential(
103 | nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
104 | nn.BatchNorm2d(out_channels // 2),
105 | nn.ReLU(inplace=True),
106 | )
107 | self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups)
108 | self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
109 |
110 | def forward(self, x):
111 | x = self.downsample(x)
112 | x = self.conv1(x)
113 | output = self.fu(x)
114 | output = self.conv2(x + output)
115 | return output
116 |
117 |
118 | class FFC(nn.Module):
119 | def __init__(
120 | self,
121 | in_channels,
122 | out_channels,
123 | kernel_size,
124 | ratio_gin,
125 | ratio_gout,
126 | stride=1,
127 | padding=0,
128 | dilation=1,
129 | groups=1,
130 | bias=False,
131 | padding_type="reflect",
132 | gated=False,
133 | ):
134 | super(FFC, self).__init__()
135 |
136 | assert stride == 1 or stride == 2, "Stride should be 1 or 2."
137 | self.stride = stride
138 |
139 | in_cg = int(in_channels * ratio_gin)
140 | in_cl = in_channels - in_cg
141 | out_cg = int(out_channels * ratio_gout)
142 | out_cl = out_channels - out_cg
143 |
144 | self.ratio_gin = ratio_gin
145 | self.ratio_gout = ratio_gout
146 | self.global_in_num = in_cg
147 |
148 | module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
149 | self.convl2l = module(
150 | in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
151 | )
152 | module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
153 | self.convl2g = module(
154 | in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
155 | )
156 | module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
157 | self.convg2l = module(
158 | in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
159 | )
160 | module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
161 | self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2)
162 |
163 | self.gated = gated
164 | module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
165 | self.gate = module(in_channels, 2, 1)
166 |
167 | def forward(self, x):
168 | x_l, x_g = x if type(x) is tuple else (x, 0)
169 | out_xl, out_xg = 0, 0
170 |
171 | if self.gated:
172 | total_input_parts = [x_l]
173 | if torch.is_tensor(x_g):
174 | total_input_parts.append(x_g)
175 | total_input = torch.cat(total_input_parts, dim=1)
176 |
177 | gates = torch.sigmoid(self.gate(total_input))
178 | g2l_gate, l2g_gate = gates.chunk(2, dim=1)
179 | else:
180 | g2l_gate, l2g_gate = 1, 1
181 |
182 | if self.ratio_gout != 1:
183 | out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
184 | if self.ratio_gout != 0:
185 | out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
186 |
187 | return out_xl, out_xg
188 |
189 |
190 | class FFC_BN_ACT(nn.Module):
191 | def __init__(
192 | self,
193 | in_channels,
194 | out_channels,
195 | kernel_size,
196 | ratio_gin=0,
197 | ratio_gout=0,
198 | stride=1,
199 | padding=0,
200 | dilation=1,
201 | groups=1,
202 | bias=False,
203 | norm_layer=nn.BatchNorm2d,
204 | activation_layer=nn.ReLU,
205 | ):
206 | super(FFC_BN_ACT, self).__init__()
207 | self.ffc = FFC(
208 | in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias
209 | )
210 | lnorm = nn.Identity if ratio_gout == 1 else norm_layer
211 | gnorm = nn.Identity if ratio_gout == 0 else norm_layer
212 | global_channels = int(out_channels * ratio_gout)
213 | self.bn_l = lnorm(out_channels - global_channels)
214 | self.bn_g = gnorm(global_channels)
215 |
216 | lact = nn.Identity if ratio_gout == 1 else activation_layer
217 | gact = nn.Identity if ratio_gout == 0 else activation_layer
218 | self.act_l = lact(inplace=True)
219 | self.act_g = gact(inplace=True)
220 |
221 | def forward(self, x):
222 | x_l, x_g = self.ffc(x)
223 | x_l = self.act_l(self.bn_l(x_l))
224 | x_g = self.act_g(self.bn_g(x_g))
225 | return x_l, x_g
226 |
227 |
228 | class FFCResnetBlock(nn.Module):
229 | def __init__(self, dim, ratio_gin, ratio_gout):
230 | super().__init__()
231 | self.conv1 = FFC_BN_ACT(
232 | dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
233 | )
234 | self.conv2 = FFC_BN_ACT(
235 | dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
236 | )
237 |
238 | def forward(self, x):
239 | x_l, x_g = x if type(x) is tuple else (x, 0)
240 | id_l, id_g = x_l, x_g
241 | x_l, x_g = self.conv1((x_l, x_g))
242 | x_l, x_g = self.conv2((x_l, x_g))
243 | x_l, x_g = id_l + x_l, id_g + x_g
244 | out = x_l, x_g
245 | return out
246 |
247 |
248 | class ConcatTupleLayer(nn.Module):
249 | def forward(self, x):
250 | assert isinstance(x, tuple)
251 | x_l, x_g = x
252 | assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
253 | if not torch.is_tensor(x_g):
254 | return x_l
255 | return torch.cat(x, dim=1)
256 |
257 |
258 | class LargeMaskInpainting(nn.Module):
259 | def __init__(self, input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024):
260 | super().__init__()
261 |
262 | model = [nn.ReflectionPad2d(3), FFC_BN_ACT(input_nc, ngf, kernel_size=7)]
263 |
264 | ### downsample
265 | for i in range(n_downsampling):
266 | mult = 2**i
267 | model += [
268 | FFC_BN_ACT(
269 | min(max_features, ngf * mult),
270 | min(max_features, ngf * mult * 2),
271 | kernel_size=3,
272 | stride=2,
273 | padding=1,
274 | ratio_gout=0.75 if i == n_downsampling - 1 else 0,
275 | )
276 | ]
277 |
278 | ### resnet blocks
279 | for i in range(n_blocks):
280 | cur_resblock = FFCResnetBlock(min(max_features, ngf * 2**n_downsampling), ratio_gin=0.75, ratio_gout=0.75)
281 | model += [cur_resblock]
282 |
283 | model += [ConcatTupleLayer()]
284 |
285 | ### upsample
286 | for i in range(n_downsampling):
287 | mult = 2 ** (n_downsampling - i)
288 | model += [
289 | nn.ConvTranspose2d(
290 | min(max_features, ngf * mult),
291 | min(max_features, int(ngf * mult / 2)),
292 | kernel_size=3,
293 | stride=2,
294 | padding=1,
295 | output_padding=1,
296 | ),
297 | nn.BatchNorm2d(min(max_features, int(ngf * mult / 2))),
298 | nn.ReLU(True),
299 | ]
300 |
301 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7), nn.Sigmoid()]
302 | self.model = nn.Sequential(*model)
303 |
304 | def forward(self, img, mask):
305 | masked_img = img * (1 - mask)
306 | masked_img = torch.cat([masked_img, mask], dim=1)
307 | pred = self.model(masked_img)
308 | inpainted = mask * pred + (1 - mask) * img
309 | return inpainted
310 |
311 |
312 | @torch.inference_mode()
313 | def inpaint_watermark(imgs):
314 | if not os.path.exists(LAMA_PATH):
315 | download(LAMA_URL, LAMA_PATH)
316 |
317 | mask = to_tensor(Image.open("./utils/mask.png").convert("L")).unsqueeze(0).to(imgs.device)
318 | if mask.shape[-1] != imgs.shape[-1]:
319 | mask = F.interpolate(mask, size=(imgs.shape[2], imgs.shape[3]), mode="nearest")
320 | mask = mask.expand(imgs.shape[0], 1, mask.shape[2], mask.shape[3])
321 |
322 | model = LargeMaskInpainting().to(imgs.device)
323 | state_dict = torch.load(LAMA_PATH, map_location=imgs.device)["state_dict"]
324 | g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")}
325 | model.load_state_dict(g_dict)
326 |
327 | inpainted = model.forward(imgs, mask)
328 |
329 | return inpainted
330 |
331 |
332 | if __name__ == "__main__":
333 | import decord
334 |
335 | decord.bridge.set_bridge("torch")
336 |
337 | if len(sys.argv) < 2:
338 |         print("Usage: python -m utils.lama <video_path>")
339 | sys.exit(1)
340 |
341 | video_path = sys.argv[1]
342 | out_path = video_path.replace(".mp4", " inpainted.mp4")
343 |
344 | vr = decord.VideoReader(video_path)
345 | fps = vr.get_avg_fps()
346 | video = rearrange(vr[:], "f h w c -> f c h w").div(255)
347 |
348 | inpainted = inpaint_watermark(video)
349 | inpainted = rearrange(inpainted, "f c h w -> f h w c").clamp(0, 1).mul(255).byte().cpu().numpy()
350 | export_to_video(inpainted, out_path, fps)
351 |
--------------------------------------------------------------------------------
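A minimal usage sketch for utils/lama.py above, running LargeMaskInpainting on a single image with a custom mask instead of the bundled watermark mask. The image and mask paths are illustrative assumptions.

import os

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

from utils.lama import LAMA_PATH, LAMA_URL, LargeMaskInpainting, download

device = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(LAMA_PATH):
    download(LAMA_URL, LAMA_PATH)

# (1, 3, H, W) image and (1, 1, H, W) mask in [0, 1]; H and W should be divisible by 8
img = to_tensor(Image.open("photo.jpg").convert("RGB")).unsqueeze(0).to(device)
mask = to_tensor(Image.open("hole_mask.png").convert("L")).unsqueeze(0).to(device)

model = LargeMaskInpainting().to(device)
state_dict = torch.load(LAMA_PATH, map_location=device)["state_dict"]
g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")}
model.load_state_dict(g_dict)

with torch.inference_mode():
    inpainted = model(img, mask)  # masked regions are replaced by the generator's prediction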
/utils/lora_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import torch
4 | from typing import Union
5 | from types import SimpleNamespace
6 | from models.unet_3d_condition_mask import UNet3DConditionModel
7 | from transformers import CLIPTextModel
8 | from utils.convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20
9 |
10 | from .lora import (
11 | extract_lora_ups_down,
12 | inject_trainable_lora_extended,
13 | save_lora_weight,
14 | train_patch_pipe,
15 | monkeypatch_or_replace_lora,
16 | monkeypatch_or_replace_lora_extended
17 | )
18 |
19 | from stable_lora.lora import (
20 | activate_lora_train,
21 | add_lora_to,
22 | save_lora,
23 | load_lora,
24 | set_mode_group
25 | )
26 |
27 | FILE_BASENAMES = ['unet', 'text_encoder']
28 | LORA_FILE_TYPES = ['.pt', '.safetensors']
29 | CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
30 | STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
31 |
32 | lora_versions = dict(
33 | stable_lora = "stable_lora",
34 | cloneofsimo = "cloneofsimo"
35 | )
36 |
37 | lora_func_types = dict(
38 | loader = "loader",
39 | injector = "injector"
40 | )
41 |
42 | lora_args = dict(
43 | model = None,
44 | loras = None,
45 | target_replace_module = [],
46 | target_module = [],
47 | r = 4,
48 | search_class = [torch.nn.Linear],
49 | dropout = 0,
50 | lora_bias = 'none'
51 | )
52 |
53 | LoraVersions = SimpleNamespace(**lora_versions)
54 | LoraFuncTypes = SimpleNamespace(**lora_func_types)
55 |
56 | LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
57 | LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
58 |
59 | def filter_dict(_dict, keys=[]):
60 | if len(keys) == 0:
61 |         raise ValueError("Keys cannot be empty when filtering the return dict.")
62 |
63 | for k in keys:
64 | if k not in lora_args.keys():
65 |             warnings.warn(f"{k} does not exist in the available LoRA arguments.")
66 |
67 | return {k: v for k, v in _dict.items() if k in keys}
68 |
69 | class LoraHandler(object):
70 | def __init__(
71 | self,
72 | version: LORA_VERSIONS = LoraVersions.cloneofsimo,
73 | use_unet_lora: bool = False,
74 | use_text_lora: bool = False,
75 | save_for_webui: bool = False,
76 | only_for_webui: bool = False,
77 | lora_bias: str = 'none',
78 | unet_replace_modules: list = ['UNet3DConditionModel'],
79 | text_encoder_replace_modules: list = ['CLIPEncoderLayer']
80 | ):
81 | self.version = version
82 | self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
83 | self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
84 | self.lora_bias = lora_bias
85 | self.use_unet_lora = use_unet_lora
86 | self.use_text_lora = use_text_lora
87 | self.save_for_webui = save_for_webui
88 | self.only_for_webui = only_for_webui
89 | self.unet_replace_modules = unet_replace_modules
90 | self.text_encoder_replace_modules = text_encoder_replace_modules
91 | self.use_lora = any([use_text_lora, use_unet_lora])
92 |
93 | if self.use_lora:
94 | print(f"Using LoRA Version: {self.version}")
95 |
96 | def is_cloneofsimo_lora(self):
97 | return self.version == LoraVersions.cloneofsimo
98 |
99 | def is_stable_lora(self):
100 | return self.version == LoraVersions.stable_lora
101 |
102 | def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader):
103 |
104 | if self.is_cloneofsimo_lora():
105 |
106 | if func_type == LoraFuncTypes.loader:
107 | return monkeypatch_or_replace_lora_extended
108 |
109 | if func_type == LoraFuncTypes.injector:
110 | return inject_trainable_lora_extended
111 |
112 | if self.is_stable_lora():
113 |
114 | if func_type == LoraFuncTypes.loader:
115 | return load_lora
116 |
117 | if func_type == LoraFuncTypes.injector:
118 | return add_lora_to
119 |
120 |         raise ValueError("LoRA version does not exist.")
121 |
122 | def check_lora_ext(self, lora_file: str):
123 | return lora_file.endswith(tuple(LORA_FILE_TYPES))
124 |
125 | def get_lora_file_path(
126 | self,
127 | lora_path: str,
128 | model: Union[UNet3DConditionModel, CLIPTextModel]
129 | ):
130 | if os.path.exists(lora_path):
131 | lora_filenames = [fns for fns in os.listdir(lora_path)]
132 | is_lora = self.check_lora_ext(lora_path)
133 |
134 | is_unet = isinstance(model, UNet3DConditionModel)
135 | is_text = isinstance(model, CLIPTextModel)
136 | idx = 0 if is_unet else 1
137 |
138 | base_name = FILE_BASENAMES[idx]
139 |
140 | for lora_filename in lora_filenames:
141 | is_lora = self.check_lora_ext(lora_filename)
142 | if not is_lora:
143 | continue
144 |
145 | if base_name in lora_filename:
146 | return os.path.join(lora_path, lora_filename)
147 |
148 | return None
149 |
150 | def handle_lora_load(self, file_name:str, lora_loader_args: dict = None):
151 | self.lora_loader(**lora_loader_args)
152 | print(f"Successfully loaded LoRA from: {file_name}")
153 |
154 | def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
155 | try:
156 | lora_file = self.get_lora_file_path(lora_path, model)
157 |
158 | if lora_file is not None:
159 | lora_loader_args.update({"lora_path": lora_file})
160 | self.handle_lora_load(lora_file, lora_loader_args)
161 |
162 | else:
163 | print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
164 |
165 | except Exception as e:
166 |             print(f"An error occurred while loading a LoRA file: {e}")
167 |
168 | def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias):
169 | return_dict = lora_args.copy()
170 |
171 | if self.is_cloneofsimo_lora():
172 | return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
173 | return_dict.update({
174 | "model": model,
175 | "loras": self.get_lora_file_path(lora_path, model),
176 | "target_replace_module": replace_modules,
177 | "r": r
178 | })
179 |
180 | if self.is_stable_lora():
181 | KEYS = ['model', 'lora_path']
182 | return_dict = filter_dict(return_dict, KEYS)
183 |
184 | return_dict.update({'model': model, 'lora_path': lora_path})
185 |
186 | return return_dict
187 |
188 | def do_lora_injection(
189 | self,
190 | model,
191 | replace_modules,
192 | bias='none',
193 | dropout=0,
194 | r=4,
195 | lora_loader_args=None,
196 | ):
197 | REPLACE_MODULES = replace_modules
198 |
199 | params = None
200 | negation = None
201 | is_injection_hybrid = False
202 |
203 | if self.is_cloneofsimo_lora():
204 | is_injection_hybrid = True
205 | injector_args = lora_loader_args
206 |
207 | params, negation = self.lora_injector(**injector_args)
208 | for _up, _down in extract_lora_ups_down(
209 | model,
210 | target_replace_module=REPLACE_MODULES):
211 |
212 | if all(x is not None for x in [_up, _down]):
213 | print(f"Lora successfully injected into {model.__class__.__name__}.")
214 |
215 | break
216 |
217 | return params, negation, is_injection_hybrid
218 |
219 | if self.is_stable_lora():
220 | injector_args = lora_args.copy()
221 | injector_args = filter_dict(injector_args, keys=STABLE_LORA_KEYS)
222 |
223 | SEARCH_CLASS = [torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Embedding]
224 |
225 | injector_args.update({
226 | "model": model,
227 | "target_module": REPLACE_MODULES,
228 | "search_class": SEARCH_CLASS,
229 | "r": r,
230 | "dropout": dropout,
231 | "lora_bias": self.lora_bias
232 | })
233 |
234 | activator = self.lora_injector(**injector_args)
235 | activator()
236 |
237 | return params, negation, is_injection_hybrid
238 |
239 | def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16):
240 |
241 | params = None
242 | negation = None
243 |
244 | lora_loader_args = self.get_lora_func_args(
245 | lora_path,
246 | use_lora,
247 | model,
248 | replace_modules,
249 | r,
250 | dropout,
251 | self.lora_bias
252 | )
253 | if use_lora:
254 | params, negation, is_injection_hybrid = self.do_lora_injection(
255 | model,
256 | replace_modules,
257 | bias=self.lora_bias,
258 | lora_loader_args=lora_loader_args,
259 | dropout=dropout,
260 | r=r
261 | )
262 |
263 | if not is_injection_hybrid:
264 | self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
265 |
266 | params = model if params is None else params
267 | return params, negation
268 |
269 |
270 | def deactivate_lora_train(self, models, deactivate=True):
271 | """
272 | Usage: Use before and after sampling previews.
273 | Currently only available for Stable LoRA.
274 | """
275 | if self.is_stable_lora():
276 | set_mode_group(models, not deactivate)
277 |
278 | def save_cloneofsimo_lora(self, model, save_path, step):
279 |
280 | def save_lora(model, name, condition, replace_modules, step, save_path):
281 | if condition and replace_modules is not None:
282 | save_path = f"{save_path}/{step}_{name}.pt"
283 | save_lora_weight(model, save_path, replace_modules)
284 |
285 | save_lora(
286 | model.unet,
287 | FILE_BASENAMES[0],
288 | self.use_unet_lora,
289 | self.unet_replace_modules,
290 | step,
291 | save_path,
292 | )
293 | save_lora(
294 | model.text_encoder,
295 | FILE_BASENAMES[1],
296 | self.use_text_lora,
297 | self.text_encoder_replace_modules,
298 | step,
299 | save_path
300 | )
301 |
302 | train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
303 |
304 | def save_stable_lora(
305 | self,
306 | model,
307 | step,
308 | name,
309 | save_path = '',
310 | save_for_webui=False,
311 | only_for_webui=False
312 | ):
313 | import uuid
314 |
315 | save_filename = f"{step}_{name}"
316 |         lora_metadata = {
317 | "stable_lora_text_to_video": "v1",
318 | "lora_name": name + "_" + uuid.uuid4().hex.lower()[:5]
319 | }
320 | save_lora(
321 | unet=model.unet,
322 | text_encoder=model.text_encoder,
323 | save_text_weights=self.use_text_lora,
324 | output_dir=save_path,
325 | lora_filename=save_filename,
326 | lora_bias=self.lora_bias,
327 | save_for_webui=self.save_for_webui,
328 | only_webui=self.only_for_webui,
329 | metadata=lora_metadata,
330 | unet_dict_converter=convert_unet_state_dict,
331 | text_dict_converter=convert_text_enc_state_dict_v20
332 | )
333 |
334 |     def save_lora_weights(self, model, save_path: str = '', step: str = ''):
335 | save_path = f"{save_path}/lora"
336 | os.makedirs(save_path, exist_ok=True)
337 |
338 | if self.is_cloneofsimo_lora():
339 | if any([self.save_for_webui, self.only_for_webui]):
340 | warnings.warn(
341 | """
342 |                     You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implementation.
343 | Only 'stable_lora' is supported for saving to a compatible webui file.
344 | """
345 | )
346 | self.save_cloneofsimo_lora(model, save_path, step)
347 |
348 | if self.is_stable_lora():
349 | name = 'lora_text_to_video'
350 | self.save_stable_lora(model, step, name, save_path)
--------------------------------------------------------------------------------
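A minimal usage sketch for utils/lora_handler.py above, showing how a training script might inject LoRA into the UNet with the cloneofsimo backend. The checkpoint path is an illustrative assumption, the from_pretrained call assumes the usual diffusers ModelMixin interface, and the final save call assumes a `pipeline` object exposing `unet` and `text_encoder`.

from models.unet_3d_condition_mask import UNet3DConditionModel
from utils.lora_handler import LoraHandler, LoraVersions

unet = UNet3DConditionModel.from_pretrained(
    "./text-to-video-ms-1.7b", subfolder="unet"  # assumed local checkpoint path
)

lora_manager = LoraHandler(
    version=LoraVersions.cloneofsimo,
    use_unet_lora=True,
    use_text_lora=False,
    unet_replace_modules=["UNet3DConditionModel"],
)

# Injects LoRA layers (or loads existing ones found under lora_path) and returns trainable params
unet_lora_params, unet_negation = lora_manager.add_lora_to_model(
    True, unet, lora_manager.unet_replace_modules, dropout=0.1, lora_path="", r=16
)

# At checkpoint time, given a pipeline with .unet and .text_encoder (assumed to exist):
# lora_manager.save_lora_weights(model=pipeline, save_path="./output/run", step=str(global_step))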
/utils/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 |
18 | from PIL import Image, ImageDraw, ImageFont
19 | import cv2
20 | import abc
21 | from typing import Optional, Union, Tuple, List, Callable, Dict
22 | #from IPython.display import display
23 | from tqdm.notebook import tqdm
24 | from diffusers.models.cross_attention import CrossAttention
25 | from diffusers.utils import PIL_INTERPOLATION
26 |
27 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
28 | h, w, c = image.shape
29 | offset = int(h * .2)
30 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
31 | font = cv2.FONT_HERSHEY_SIMPLEX
32 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
33 | img[:h] = image
34 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
35 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
36 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
37 | return img
38 |
39 |
40 | def prepare_image(image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance=True
41 | ):
42 | if not isinstance(image, torch.Tensor):
43 | if isinstance(image, Image.Image):
44 | image = [image]
45 |
46 | if isinstance(image[0], Image.Image):
47 | images = []
48 |
49 | for image_ in image:
50 | image_ = image_.convert("RGB")
51 | image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
52 | image_ = np.array(image_)
53 | image_ = image_[None, :]
54 | images.append(image_)
55 |
56 | image = images
57 |
58 | image = np.concatenate(image, axis=0)
59 | image = np.array(image).astype(np.float32) / 255.0
60 | image = image.transpose(0, 3, 1, 2)
61 | image = torch.from_numpy(image)
62 | elif isinstance(image[0], torch.Tensor):
63 | image = torch.cat(image, dim=0)
64 |
65 | image_batch_size = image.shape[0]
66 |
67 | if image_batch_size == 1:
68 | repeat_by = batch_size
69 | else:
70 | # image batch size is the same as prompt batch size
71 | repeat_by = num_videos_per_prompt
72 |
73 | image = image.repeat_interleave(repeat_by, dim=0)
74 |
75 | image = image.to(device=device, dtype=dtype)
76 |
77 | if do_classifier_free_guidance:
78 | image = torch.cat([image] * 2)
79 |
80 | return image
81 |
82 | def view_images(images, save_name, num_rows=1, offset_ratio=0.02):
83 | if type(images) is list:
84 | num_empty = len(images) % num_rows
85 | elif images.ndim == 4:
86 | num_empty = images.shape[0] % num_rows
87 | else:
88 | images = [images]
89 | num_empty = 0
90 |
91 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
92 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
93 | num_items = len(images)
94 | #print(images[0].shape)
95 | h, w, c = images[0].shape
96 | offset = int(h * offset_ratio)
97 | num_cols = num_items // num_rows
98 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
99 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
100 | for i in range(num_rows):
101 | for j in range(num_cols):
102 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
103 | i * num_cols + j]
104 |
105 | pil_img = Image.fromarray(image_)
106 | pil_img.save('output/{}.png'.format(save_name))
107 | #display(pil_img)
108 |
109 |
110 | def diffusion_step(model, latents, context, t, guidance_scale, control_img, low_resource=False,
111 |                    control=False):
112 | controlnet_conditioning_scale = 1.0
113 | image = prepare_image(
114 | control_img,
115 | 512,
116 | 512,
117 | 1,
118 | 1,
119 | model.device,
120 | model.unet.dtype,
121 | )
122 | if low_resource:
123 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
124 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
125 | else:
126 | latents_input = torch.cat([latents] * 2)
127 | #latents_input = model.scheduler.scale_model_input(latents_input, t)
128 | #print(latent_model_input.shape, context.shape, image.shape)
129 | if control:
130 | down_block_res_samples, mid_block_res_sample = model.controlnet(
131 | latents_input,
132 | t,
133 | encoder_hidden_states=context,
134 | controlnet_cond=image,
135 | return_dict=False,
136 | )
137 | down_block_res_samples = [
138 | down_block_res_sample * controlnet_conditioning_scale
139 | for down_block_res_sample in down_block_res_samples
140 | ]
141 | mid_block_res_sample *= controlnet_conditioning_scale
142 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context,
143 | down_block_additional_residuals=down_block_res_samples,
144 | mid_block_additional_residual=mid_block_res_sample,)["sample"]
145 | else:
146 | noise_pred = model.unet(latents_input,t, encoder_hidden_states=context)["sample"]
147 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
148 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
149 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
150 | #latents = controller.step_callback(latents)
151 | return latents
152 |
153 |
154 | def latent2image(vae, latents):
155 | latents = 1 / 0.18215 * latents
156 | image = vae.decode(latents)['sample']
157 | image = (image / 2 + 0.5).clamp(0, 1)
158 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
159 | image = (image * 255).astype(np.uint8)
160 | return image
161 |
162 |
163 | def init_latent(latent, model, height, width, generator, batch_size):
164 | if latent is None:
165 | latent = torch.randn(
166 | (1, model.unet.in_channels, height // 8, width // 8),
167 | generator=generator,
168 | )
169 | #print(latent.shape,batch_size,model.unet.in_channels)
170 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
171 | return latent, latents
172 |
173 |
174 |
175 |
176 | class AttentionControl(abc.ABC):
177 |
178 | def step_callback(self, x_t):
179 | return x_t
180 |
181 | def between_steps(self):
182 | return
183 |
184 | @property
185 | def num_uncond_att_layers(self):
186 | return 0
187 |
188 | @abc.abstractmethod
189 | def forward (self, attn, is_cross: bool, place_in_unet: str):
190 | raise NotImplementedError
191 |
192 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
193 | #self.reset()
194 | if self.cur_att_layer >= self.num_uncond_att_layers:
195 | h = attn.shape[0]
196 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
197 | self.cur_att_layer += 1
198 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
199 | self.cur_att_layer = 0
200 | self.cur_step += 1
201 | self.between_steps()
202 | return attn
203 |
204 | def reset(self):
205 | self.cur_step = 0
206 | self.cur_att_layer = 0
207 |
208 | def __init__(self):
209 | self.cur_step = 0
210 | self.num_att_layers = -1
211 | self.cur_att_layer = 0
212 |
213 |
214 | class AttentionStore(AttentionControl):
215 |
216 | @staticmethod
217 | def get_empty_store():
218 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
219 | "down_self": [], "mid_self": [], "up_self": []}
220 |
221 | def forward(self, attn, is_cross: bool, place_in_unet: str):
222 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
223 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
224 | self.step_store[key].append(attn)
225 | return attn
226 |
227 | def between_steps(self):
228 | if len(self.attention_store) == 0:
229 | self.attention_store = self.step_store
230 | else:
231 | for key in self.attention_store:
232 | for i in range(len(self.attention_store[key])):
233 | self.attention_store[key][i] += self.step_store[key][i]
234 | self.step_store = self.get_empty_store()
235 |
236 | def get_average_attention(self):
237 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
238 | return average_attention
239 |
240 |
241 | def reset(self):
242 | super(AttentionStore, self).reset()
243 | self.step_store = self.get_empty_store()
244 | self.attention_store = {}
245 |
246 | def __init__(self):
247 | super(AttentionStore, self).__init__()
248 | self.step_store = self.get_empty_store()
249 | self.attention_store = {}
250 |
251 |
252 | def load_512(image_path, left=0, right=0, top=0, bottom=0):
253 | if type(image_path) is str:
254 | image = np.array(Image.open(image_path))[:, :, :3]
255 | else:
256 | image = image_path
257 | h, w, c = image.shape
258 | left = min(left, w-1)
259 | right = min(right, w - left - 1)
260 | top = min(top, h - left - 1)
261 | bottom = min(bottom, h - top - 1)
262 | image = image[top:h-bottom, left:w-right]
263 | h, w, c = image.shape
264 | if h < w:
265 | offset = (w - h) // 2
266 | image = image[:, offset:offset + h]
267 | elif w < h:
268 | offset = (h - w) // 2
269 | image = image[offset:offset + w]
270 | image = np.array(Image.fromarray(image).resize((512, 512)))
271 | return image
272 |
273 | class P2PCrossAttnProcessor:
274 |
275 | def __init__(self, controller, place_in_unet):
276 | super().__init__()
277 | self.controller = controller
278 | self.place_in_unet = place_in_unet
279 |
280 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
281 | batch_size, sequence_length, _ = hidden_states.shape
282 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length,batch_size=1)
283 |
284 | query = attn.to_q(hidden_states)
285 |
286 | is_cross = encoder_hidden_states is not None
287 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
288 | key = attn.to_k(encoder_hidden_states)
289 | value = attn.to_v(encoder_hidden_states)
290 |
291 | query = attn.head_to_batch_dim(query)
292 | key = attn.head_to_batch_dim(key)
293 | value = attn.head_to_batch_dim(value)
294 |
295 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
296 |
297 | # one line change
298 | self.controller(attention_probs, is_cross, self.place_in_unet)
299 |
300 | hidden_states = torch.bmm(attention_probs, value)
301 | hidden_states = attn.batch_to_head_dim(hidden_states)
302 |
303 | # linear proj
304 | hidden_states = attn.to_out[0](hidden_states)
305 | # dropout
306 | hidden_states = attn.to_out[1](hidden_states)
307 |
308 | return hidden_states
309 |
310 | def register_attention_control(model, controller, controller1):
311 | attn_procs = {}
312 | cross_att_count = 0
313 | for name in model.unet.attn_processors.keys():
314 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
315 | if name.startswith("mid_block"):
316 | hidden_size = model.unet.config.block_out_channels[-1]
317 | place_in_unet = "mid"
318 | elif name.startswith("up_blocks"):
319 | block_id = int(name[len("up_blocks.")])
320 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
321 | place_in_unet = "up"
322 | elif name.startswith("down_blocks"):
323 | block_id = int(name[len("down_blocks.")])
324 | hidden_size = model.unet.config.block_out_channels[block_id]
325 | place_in_unet = "down"
326 | else:
327 | continue
328 | cross_att_count += 1
329 | attn_procs[name] = P2PCrossAttnProcessor(
330 | controller=controller, place_in_unet=place_in_unet
331 | )
332 | model.unet.set_attn_processor(attn_procs)
333 | controller.num_att_layers = cross_att_count
334 |
335 |
336 | attn_procs = {}
337 | cross_att_count = 0
338 | for name in model.controlnet.attn_processors.keys():
339 | cross_attention_dim = None if name.endswith("attn1.processor") else model.controlnet.config.cross_attention_dim
340 | if name.startswith("mid_block"):
341 | hidden_size = model.controlnet.config.block_out_channels[-1]
342 | place_in_unet = "mid"
343 | elif name.startswith("up_blocks"):
344 | block_id = int(name[len("up_blocks.")])
345 | hidden_size = list(reversed(model.controlnet.config.block_out_channels))[block_id]
346 | place_in_unet = "up"
347 | elif name.startswith("down_blocks"):
348 | block_id = int(name[len("down_blocks.")])
349 | hidden_size = model.controlnet.config.block_out_channels[block_id]
350 | place_in_unet = "down"
351 | else:
352 | continue
353 | cross_att_count += 1
354 | attn_procs[name] = P2PCrossAttnProcessor(
355 | controller=controller1, place_in_unet=place_in_unet
356 | )
357 |
358 | #model.unet.set_attn_processor(attn_procs)
359 | model.controlnet.set_attn_processor(attn_procs)
360 | controller1.num_att_layers = cross_att_count
361 |
362 |
363 | def get_word_inds(text: str, word_place: int, tokenizer):
364 | split_text = text.split(" ")
365 | if type(word_place) is str:
366 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
367 | elif type(word_place) is int:
368 | word_place = [word_place]
369 | out = []
370 | if len(word_place) > 0:
371 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
372 | cur_len, ptr = 0, 0
373 |
374 | for i in range(len(words_encode)):
375 | cur_len += len(words_encode[i])
376 | if ptr in word_place:
377 | out.append(i + 1)
378 | if cur_len >= len(split_text[ptr]):
379 | ptr += 1
380 | cur_len = 0
381 | return np.array(out)
382 |
383 |
384 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
385 | word_inds: Optional[torch.Tensor]=None):
386 | if type(bounds) is float:
387 | bounds = 0, bounds
388 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
389 | if word_inds is None:
390 | word_inds = torch.arange(alpha.shape[2])
391 | alpha[: start, prompt_ind, word_inds] = 0
392 | alpha[start: end, prompt_ind, word_inds] = 1
393 | alpha[end:, prompt_ind, word_inds] = 0
394 | return alpha
395 |
396 |
397 | def get_time_words_attention_alpha(prompts, num_steps,
398 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
399 | tokenizer, max_num_words=77):
400 | if type(cross_replace_steps) is not dict:
401 | cross_replace_steps = {"default_": cross_replace_steps}
402 | if "default_" not in cross_replace_steps:
403 | cross_replace_steps["default_"] = (0., 1.)
404 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
405 | for i in range(len(prompts) - 1):
406 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
407 | i)
408 | for key, item in cross_replace_steps.items():
409 | if key != "default_":
410 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
411 | for i, ind in enumerate(inds):
412 | if len(ind) > 0:
413 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
414 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
415 | return alpha_time_words
416 |
--------------------------------------------------------------------------------
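A minimal usage sketch for utils/ptp_utils.py above, collecting attention maps with AttentionStore. The `pipe` object is an assumption: register_attention_control expects something exposing both `unet` and `controlnet` with diffusers-style attn_processors.

from utils.ptp_utils import AttentionStore, register_attention_control

controller = AttentionStore()   # gathers UNet cross/self-attention per step
controller1 = AttentionStore()  # gathers ControlNet cross/self-attention per step

register_attention_control(pipe, controller, controller1)  # `pipe` is assumed to exist in the caller

# ... run the sampling loop, e.g. by calling diffusion_step(...) each timestep ...

avg_attn = controller.get_average_attention()
# dict keyed by "down_cross", "mid_cross", "up_cross", "down_self", "mid_self", "up_self",
# each a list of attention tensors averaged over the denoising steps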
/utils/seq_aligner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 |
17 |
18 | class ScoreParams:
19 |
20 | def __init__(self, gap, match, mismatch):
21 | self.gap = gap
22 | self.match = match
23 | self.mismatch = mismatch
24 |
25 | def mis_match_char(self, x, y):
26 | if x != y:
27 | return self.mismatch
28 | else:
29 | return self.match
30 |
31 |
32 | def get_matrix(size_x, size_y, gap):
33 | matrix = []
34 | for i in range(len(size_x) + 1):
35 | sub_matrix = []
36 | for j in range(len(size_y) + 1):
37 | sub_matrix.append(0)
38 | matrix.append(sub_matrix)
39 | for j in range(1, len(size_y) + 1):
40 | matrix[0][j] = j*gap
41 | for i in range(1, len(size_x) + 1):
42 | matrix[i][0] = i*gap
43 | return matrix
44 |
45 |
46 | def get_matrix(size_x, size_y, gap):
47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
50 | return matrix
51 |
52 |
53 | def get_traceback_matrix(size_x, size_y):
54 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
55 | matrix[0, 1:] = 1
56 | matrix[1:, 0] = 2
57 | matrix[0, 0] = 4
58 | return matrix
59 |
60 |
61 | def global_align(x, y, score):
62 | matrix = get_matrix(len(x), len(y), score.gap)
63 | trace_back = get_traceback_matrix(len(x), len(y))
64 | for i in range(1, len(x) + 1):
65 | for j in range(1, len(y) + 1):
66 | left = matrix[i, j - 1] + score.gap
67 | up = matrix[i - 1, j] + score.gap
68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
69 | matrix[i, j] = max(left, up, diag)
70 | if matrix[i, j] == left:
71 | trace_back[i, j] = 1
72 | elif matrix[i, j] == up:
73 | trace_back[i, j] = 2
74 | else:
75 | trace_back[i, j] = 3
76 | return matrix, trace_back
77 |
78 |
79 | def get_aligned_sequences(x, y, trace_back):
80 | x_seq = []
81 | y_seq = []
82 | i = len(x)
83 | j = len(y)
84 | mapper_y_to_x = []
85 | while i > 0 or j > 0:
86 | if trace_back[i, j] == 3:
87 | x_seq.append(x[i-1])
88 | y_seq.append(y[j-1])
89 | i = i-1
90 | j = j-1
91 | mapper_y_to_x.append((j, i))
92 | elif trace_back[i][j] == 1:
93 | x_seq.append('-')
94 | y_seq.append(y[j-1])
95 | j = j-1
96 | mapper_y_to_x.append((j, -1))
97 | elif trace_back[i][j] == 2:
98 | x_seq.append(x[i-1])
99 | y_seq.append('-')
100 | i = i-1
101 | elif trace_back[i][j] == 4:
102 | break
103 | mapper_y_to_x.reverse()
104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
105 |
106 |
107 | def get_mapper(x: str, y: str, tokenizer, max_len=77):
108 | x_seq = tokenizer.encode(x)
109 | y_seq = tokenizer.encode(y)
110 | score = ScoreParams(0, 1, -1)
111 | matrix, trace_back = global_align(x_seq, y_seq, score)
112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
113 | alphas = torch.ones(max_len)
114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
115 | mapper = torch.zeros(max_len, dtype=torch.int64)
116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
118 | return mapper, alphas
119 |
120 |
121 | def get_refinement_mapper(prompts, tokenizer, max_len=77):
122 | x_seq = prompts[0]
123 | mappers, alphas = [], []
124 | for i in range(1, len(prompts)):
125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
126 | mappers.append(mapper)
127 | alphas.append(alpha)
128 | return torch.stack(mappers), torch.stack(alphas)
129 |
130 |
131 | def get_word_inds(text: str, word_place: int, tokenizer):
132 | split_text = text.split(" ")
133 | if type(word_place) is str:
134 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
135 | elif type(word_place) is int:
136 | word_place = [word_place]
137 | out = []
138 | if len(word_place) > 0:
139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
140 | cur_len, ptr = 0, 0
141 |
142 | for i in range(len(words_encode)):
143 | cur_len += len(words_encode[i])
144 | if ptr in word_place:
145 | out.append(i + 1)
146 | if cur_len >= len(split_text[ptr]):
147 | ptr += 1
148 | cur_len = 0
149 | return np.array(out)
150 |
151 |
152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
153 | words_x = x.split(' ')
154 | words_y = y.split(' ')
155 | if len(words_x) != len(words_y):
156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
161 | mapper = np.zeros((max_len, max_len))
162 | i = j = 0
163 | cur_inds = 0
164 | while i < max_len and j < max_len:
165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
167 | if len(inds_source_) == len(inds_target_):
168 | mapper[inds_source_, inds_target_] = 1
169 | else:
170 | ratio = 1 / len(inds_target_)
171 | for i_t in inds_target_:
172 | mapper[inds_source_, i_t] = ratio
173 | cur_inds += 1
174 | i += len(inds_source_)
175 | j += len(inds_target_)
176 | elif cur_inds < len(inds_source):
177 | mapper[i, j] = 1
178 | i += 1
179 | j += 1
180 | else:
181 | mapper[j, j] = 1
182 | i += 1
183 | j += 1
184 |
185 | return torch.from_numpy(mapper).float()
186 |
187 |
188 |
189 | def get_replacement_mapper(prompts, tokenizer, max_len=77):
190 | x_seq = prompts[0]
191 | mappers = []
192 | for i in range(1, len(prompts)):
193 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
194 | mappers.append(mapper)
195 | return torch.stack(mappers)
196 |
197 |
--------------------------------------------------------------------------------
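A minimal usage sketch for utils/seq_aligner.py above, building prompt-to-prompt token mappers. The CLIP tokenizer id is an assumption; any CLIP-style tokenizer with BOS/EOS tokens should behave the same way.

from transformers import CLIPTokenizer

from utils.seq_aligner import get_refinement_mapper, get_replacement_mapper

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed tokenizer

# Word-for-word replacement edit: both prompts must have the same number of words
swap_prompts = ["a cat sitting on a bench", "a dog sitting on a bench"]
swap_mapper = get_replacement_mapper(swap_prompts, tokenizer)       # shape (1, 77, 77)

# Refinement edit: the second prompt adds words; alphas flag tokens that have a counterpart
refine_prompts = ["a photo of a house", "a photo of a house on a hill"]
mappers, alphas = get_refinement_mapper(refine_prompts, tokenizer)  # shapes (1, 77) and (1, 77)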