├── .gitignore
├── ControlNeXt-SD1.5-Training
│ ├── README.md
│ ├── examples
│ │ ├── conditioning_image_1.png
│ │ └── conditioning_image_2.png
│ ├── models
│ │ ├── controlnext.py
│ │ └── unet.py
│ ├── pipeline
│ │ └── pipeline_controlnext.py
│ ├── run_controlnext.py
│ ├── scripts.sh
│ └── train_controlnext.py
├── ControlNeXt-SD1.5
│ ├── README.md
│ ├── examples
│ │ ├── deepfashion_caption
│ │ │ ├── condition_0.png
│ │ │ ├── condition_1.png
│ │ │ ├── eval_img
│ │ │ │ ├── chinese_style.jpg
│ │ │ │ ├── image_0.jpg
│ │ │ │ ├── warrior_bad.jpg
│ │ │ │ └── warrior_good.jpg
│ │ │ └── script.sh
│ │ ├── deepfashion_multiview
│ │ │ ├── condition_0.jpg
│ │ │ ├── condition_1.jpg
│ │ │ ├── condition_2.jpg
│ │ │ ├── eval_img
│ │ │ │ ├── Anythingv3.jpg
│ │ │ │ ├── Anythingv3_fischl.jpg
│ │ │ │ ├── Anythingv3_lisa.jpg
│ │ │ │ └── DreamShaper.jpg
│ │ │ └── script.sh
│ │ ├── deepfashoin_mask
│ │ │ ├── condition_0.png
│ │ │ ├── eval_img
│ │ │ │ └── image_0.jpg
│ │ │ └── script.sh
│ │ └── vidit_depth
│ │ │ ├── condition_0.png
│ │ │ ├── eval_img
│ │ │ │ └── image_0.jpg
│ │ │ └── script.sh
│ ├── models
│ │ ├── controlnet.py
│ │ ├── pipeline_controlnext.py
│ │ └── unet.py
│ └── run_controlnext.py
├── ControlNeXt-SD3
│ └── README.md
├── ControlNeXt-SDXL-Training
│ ├── README.md
│ ├── examples
│ │ └── vidit_depth
│ │ │ ├── condition_0.png
│ │ │ └── train.sh
│ ├── models
│ │ ├── controlnet.py
│ │ └── unet.py
│ ├── pipeline
│ │ └── pipeline_controlnext.py
│ ├── requirements.txt
│ ├── train_controlnext.py
│ └── utils
│ │ ├── preprocess.py
│ │ ├── tools.py
│ │ └── utils.py
├── ControlNeXt-SDXL
│ ├── README.md
│ ├── examples
│ │ ├── anime_canny
│ │ │ ├── condition_0.jpg
│ │ │ ├── eval_img
│ │ │ │ ├── AAM.jpg
│ │ │ │ └── NetaXLV2.jpg
│ │ │ ├── image_0.jpg
│ │ │ ├── run.sh
│ │ │ └── run_with_pp.sh
│ │ ├── demo
│ │ │ ├── demo1.jpg
│ │ │ ├── demo2.jpg
│ │ │ ├── demo3.jpg
│ │ │ ├── demo4.jpg
│ │ │ └── demo5.jpg
│ │ └── vidit_depth
│ │ │ ├── condition_0.png
│ │ │ ├── eval_img
│ │ │ │ ├── StableDiffusionXL.jpg
│ │ │ │ └── StableDiffusionXL_GlassSculpturesLora.jpg
│ │ │ └── run.sh
│ ├── models
│ │ ├── controlnet.py
│ │ └── unet.py
│ ├── pipeline
│ │ └── pipeline_controlnext.py
│ ├── requirements.txt
│ ├── run_controlnext.py
│ └── utils
│ │ ├── preprocess.py
│ │ ├── tools.py
│ │ └── utils.py
├── ControlNeXt-SVD-v2-Training
│ ├── README.md
│ ├── deepspeed.yaml
│ ├── meta_info_example
│ │ ├── meta_info.json
│ │ └── meta_info
│ │ │ └── 1.json
│ ├── models
│ │ ├── controlnext_vid_svd.py
│ │ └── unet_spatio_temporal_condition_controlnext.py
│ ├── pipeline
│ │ └── pipeline_stable_video_diffusion_controlnext.py
│ ├── requirements.txt
│ ├── script.sh
│ ├── train_svd.py
│ └── utils
│ │ ├── dataset.py
│ │ ├── extract_learned_paras.py
│ │ ├── extract_vid2img.py
│ │ ├── img_dataset.py
│ │ ├── pkl_dataset.py
│ │ ├── scheduling_euler_discrete_karras_fix.py
│ │ ├── ubc_dataset.py
│ │ ├── unwrap_deepspeed.py
│ │ ├── util.py
│ │ └── vid_dataset.py
├── ControlNeXt-SVD-v2
│ ├── README.md
│ ├── dwpose
│ │ ├── __init__.py
│ │ ├── dwpose_detector.py
│ │ ├── onnxdet.py
│ │ ├── onnxpose.py
│ │ ├── preprocess.py
│ │ ├── util.py
│ │ └── wholebody.py
│ ├── examples
│ │ ├── demos
│ │ │ ├── 01-1.mp4
│ │ │ ├── 02-1.mp4
│ │ │ ├── 03-1.mp4
│ │ │ └── 04-1.mp4
│ │ ├── facefusion
│ │ │ └── facefusion.jpg
│ │ ├── ref_imgs
│ │ │ ├── 01.jpeg
│ │ │ ├── 02.jpeg
│ │ │ ├── 03.jpeg
│ │ │ └── 04.jpeg
│ │ └── video
│ │ │ ├── 01.mp4
│ │ │ └── 02.mp4
│ ├── models
│ │ ├── controlnext_vid_svd.py
│ │ └── unet_spatio_temporal_condition_controlnext.py
│ ├── pipeline
│ │ └── pipeline_stable_video_diffusion_controlnext.py
│ ├── run_controlnext.py
│ ├── script.sh
│ └── utils
│ │ ├── pre_process.py
│ │ └── scheduling_euler_discrete_karras_fix.py
├── ControlNeXt-SVD
│ ├── README.md
│ ├── examples
│ │ ├── facefusion
│ │ │ └── facefusion.jpg
│ │ ├── pose
│ │ │ └── pose.mp4
│ │ └── ref_imgs
│ │ │ ├── spiderman.jpg
│ │ │ └── tiktok.png
│ ├── models
│ │ ├── controlnext_vid_svd.py
│ │ └── unet_spatio_temporal_condition_controlnext.py
│ ├── outputs
│ │ ├── chair
│ │ │ └── chair.mp4
│ │ ├── collected
│ │ │ ├── demo.jpg
│ │ │ ├── demo.mp4
│ │ │ └── out2.mp4
│ │ ├── spiderman
│ │ │ └── spiderman.mp4
│ │ ├── star
│ │ │ └── star.mp4
│ │ └── tiktok
│ │ │ └── tiktok.mp4
│ ├── pipeline
│ │ └── pipeline_stable_video_diffusion_controlnext.py
│ ├── run_controlnext.py
│ ├── script.sh
│ └── utils
│ │ └── scheduling_euler_discrete_karras_fix.py
├── LICENSE
├── README.md
├── compress_image.py
├── experiences.md
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5-Training/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SD1.5
2 |
3 |
4 | This is the training script for our ControlNeXt model, based on Stable Diffusion 1.5.
5 |
6 | Our training and inference code has undergone some updates compared to the original version. Please refer to this version as the standard.
7 |
8 | We provide an example using an open dataset, where our method achieves convergence in just a thousand training steps.
9 |
10 | ## Train
11 |
12 |
13 | ```
14 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \
15 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
16 | --output_dir="checkpoints" \
17 | --dataset_name=fusing/fill50k \
18 | --resolution=512 \
19 | --learning_rate=1e-5 \
20 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
21 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
22 | --checkpoints_total_limit 3 \
23 | --checkpointing_steps 400 \
24 | --validation_steps 400 \
25 | --num_train_epochs 4 \
26 | --train_batch_size=6 \
27 | --controlnext_scale 0.35 \
28 | --save_load_weights_increaments
29 | ```
30 |
31 | > --controlnext_scale: Set within [0, 1]; controls the strength of ControlNeXt, with larger values giving stronger control. For tasks requiring dense conditional controls, such as depth, a larger setting (such as 1.0) provides better control. Increasing this value leads to faster convergence and stronger control, but it can sometimes overly influence the final generation.
32 |
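Below is a minimal, illustrative sketch (not the training or inference code itself) of what this scale does: the control module produces a feature map together with its scale, and the scaled features are added to the UNet's intermediate features. The tensor names and the exact injection point are assumptions for illustration only.

```python
import torch
from models.controlnext import ControlNeXtModel

controlnext = ControlNeXtModel(controlnext_scale=0.35)
hint = torch.randn(1, 3, 512, 512)                 # conditioning image in pixel space
timestep = torch.tensor([999])

controls = controlnext(hint, timestep)             # {'output': (1, 320, 64, 64), 'scale': 0.35}
# A scale of 0 disables the control; values closer to 1 steer generation more strongly.
unet_features = torch.randn_like(controls['output'])   # stand-in for the UNet's features
unet_features = unet_features + controls['scale'] * controls['output']
```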
33 |
34 | > --save_load_weights_increaments: Choose whether to save the trainable parameters directly or only the weight increments, i.e., $W_{finetune} - W_{pretrained}$. Saving increments is useful when adapting to various backbones.
35 |
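As a rough sketch of what saving increments amounts to (a simplified assumption, not the training script's exact code; `trainable_names` is a hypothetical list of finetuned UNet keys):

```python
import torch

def save_weight_increments(finetuned_unet, pretrained_unet, trainable_names, path):
    """Store only the increments of the finetuned parameters: dW = W_finetune - W_pretrained."""
    tuned, base = finetuned_unet.state_dict(), pretrained_unet.state_dict()
    increments = {name: tuned[name] - base[name] for name in trainable_names}
    torch.save(increments, path)  # e.g. "unet_weight_increasements.bin"
```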
36 | ## Inference
37 |
38 |
39 | ```
40 | python run_controlnext.py \
41 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
42 | --output_dir="test" \
43 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
44 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
45 | --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \
46 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \
47 | --controlnext_scale 0.35
48 | ```
49 |
50 | ```
51 | python run_controlnext.py \
52 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
53 | --output_dir="test" \
54 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
55 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
56 | --controlnet_model_name_or_path checkpoints/checkpoint-800/controlnext.bin \
57 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet_weight_increasements.bin \
58 | --controlnext_scale 0.35 \
59 | --save_load_weights_increaments
60 | ```
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5-Training/models/controlnext.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D
10 |
11 |
12 | class ControlNeXtModel(ModelMixin, ConfigMixin):
13 | _supports_gradient_checkpointing = True
14 |
15 | @register_to_config
16 | def __init__(
17 | self,
18 | time_embed_dim = 256,
19 | in_channels = [128, 128],
20 | out_channels = [128, 256],
21 | groups = [4, 8],
22 | controlnext_scale=1.
23 | ):
24 | super().__init__()
25 |
26 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
27 | self.time_embedding = TimestepEmbedding(128, time_embed_dim)
28 | self.embedding = nn.Sequential(
29 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
30 | nn.GroupNorm(2, 64),
31 | nn.ReLU(),
32 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
33 | nn.GroupNorm(2, 64),
34 | nn.ReLU(),
35 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
36 | nn.GroupNorm(2, 128),
37 | nn.ReLU(),
38 | )
39 |
40 | self.down_res = nn.ModuleList()
41 | self.down_sample = nn.ModuleList()
42 | for i in range(len(in_channels)):
43 | self.down_res.append(
44 | ResnetBlock2D(
45 | in_channels=in_channels[i],
46 | out_channels=out_channels[i],
47 | temb_channels=time_embed_dim,
48 | groups=groups[i]
49 | ),
50 | )
51 | self.down_sample.append(
52 | Downsample2D(
53 | out_channels[i],
54 | use_conv=True,
55 | out_channels=out_channels[i],
56 | padding=1,
57 | name="op",
58 | )
59 | )
60 |
61 | self.mid_convs = nn.ModuleList()
62 | self.mid_convs.append(nn.Sequential(
63 | nn.Conv2d(
64 | in_channels=out_channels[-1],
65 | out_channels=out_channels[-1],
66 | kernel_size=3,
67 | stride=1,
68 | padding=1
69 | ),
70 | nn.ReLU(),
71 | nn.GroupNorm(8, out_channels[-1]),
72 | nn.Conv2d(
73 | in_channels=out_channels[-1],
74 | out_channels=out_channels[-1],
75 | kernel_size=3,
76 | stride=1,
77 | padding=1
78 | ),
79 | nn.GroupNorm(8, out_channels[-1]),
80 | ))
81 | self.mid_convs.append(
82 | nn.Conv2d(
83 | in_channels=out_channels[-1],
84 | out_channels=320,
85 | kernel_size=1,
86 | stride=1,
87 | ))
88 |
89 | self.scale = controlnext_scale
90 |
91 | def forward(
92 | self,
93 | sample: torch.FloatTensor,
94 | timestep: Union[torch.Tensor, float, int],
95 | ):
96 |
97 | timesteps = timestep
98 | if not torch.is_tensor(timesteps):
99 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
100 | # This would be a good case for the `match` statement (Python 3.10+)
101 | is_mps = sample.device.type == "mps"
102 | if isinstance(timestep, float):
103 | dtype = torch.float32 if is_mps else torch.float64
104 | else:
105 | dtype = torch.int32 if is_mps else torch.int64
106 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
107 | elif len(timesteps.shape) == 0:
108 | timesteps = timesteps[None].to(sample.device)
109 |
110 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
111 | batch_size = sample.shape[0]
112 | timesteps = timesteps.expand(batch_size)
113 |
114 | t_emb = self.time_proj(timesteps)
115 |
116 | # `Timesteps` does not contain any weights and will always return f32 tensors
117 | # but time_embedding might actually be running in fp16. so we need to cast here.
118 | # there might be better ways to encapsulate this.
119 | t_emb = t_emb.to(dtype=sample.dtype)
120 |
121 | emb = self.time_embedding(t_emb)
122 |
123 | sample = self.embedding(sample)
124 |
125 | for res, downsample in zip(self.down_res, self.down_sample):
126 | sample = res(sample, emb)
127 | sample = downsample(sample, emb)
128 |
129 | sample = self.mid_convs[0](sample) + sample
130 | sample = self.mid_convs[1](sample)
131 |
132 | return {
133 | 'output': sample,
134 | 'scale': self.scale,
135 | }
136 |
137 |
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5-Training/scripts.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \
2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3 | --output_dir="checkpoints" \
4 | --dataset_name=fusing/fill50k \
5 | --resolution=512 \
6 | --learning_rate=1e-5 \
7 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
8 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
9 | --checkpoints_total_limit 3 \
10 | --checkpointing_steps 400 \
11 | --validation_steps 400 \
12 | --num_train_epochs 4 \
13 | --train_batch_size=6 \
14 | --controlnext_scale 0.35 \
15 | --save_load_weights_increaments
16 |
17 |
18 |
19 |
20 | CUDA_VISIBLE_DEVICES=4 python run_controlnext.py \
21 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
22 | --output_dir="test" \
23 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
24 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
25 | --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \
26 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \
27 | --controlnext_scale 0.35
28 |
29 |
30 |
31 | CUDA_VISIBLE_DEVICES=5 python run_controlnext.py \
32 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
33 | --output_dir="test" \
34 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \
35 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
36 | --controlnet_model_name_or_path checkpoints/checkpoint-400/controlnext.bin \
37 | --unet_model_name_or_path checkpoints/checkpoint-400/unet_weight_increasements.bin \
38 | --controlnext_scale 0.35 \
39 | --save_load_weights_increaments
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SD1.5
2 |
3 | `Please refer to SDXL and SVD for our newly updated version !`
4 |
5 | This is our implementation of ControlNeXt based on [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
6 |
7 | > Please refer to [Examples](#examples) for further intuitive details.\
8 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\
9 | > Please refer to [Model Zoo](#model-zoo) for more of our trained models.
10 |
11 | Our method demonstrates the advantages listed below:
12 |
13 | - **Few trainable parameters**: only requiring **5~30M** trainable parameters (occupying 20~80 MB of memory); see the sketch after this list for how to check this on your own checkpoint.
14 | - **Fast training speed**: no sudden convergence.
15 | - **Efficient**: no need for an additional branch; only a lightweight module is required.
16 | - **Compatibility**: can serve as a **plug-and-play** lightweight module and can be combined with other LoRA weights.
17 |
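To check these numbers for a module you load yourself, a small generic PyTorch sketch (not part of this repository):

```python
import torch

def report_trainable(module: torch.nn.Module) -> None:
    """Print the trainable parameter count and its approximate memory footprint."""
    params = [p for p in module.parameters() if p.requires_grad]
    n_params = sum(p.numel() for p in params)
    size_mb = sum(p.numel() * p.element_size() for p in params) / 2**20
    print(f"{n_params / 1e6:.1f}M trainable parameters, ~{size_mb:.0f} MB")
```

For example, calling `report_trainable` on the loaded control module should land in the range quoted above.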
18 | # Examples
19 |
20 | The demo examples are generated using ControlNeXt trained on the deepfashion_multiview dataset, with [DreamShaper](https://huggingface.co/Lykon/DreamShaper) as the base model. Our method demonstrates excellent compatibility and can be applied to most other SD1.5-based models and LoRA weights. You can also retrain your own model for better performance.
21 |
22 | ## BaseModel
23 |
24 | As a plug-and-play module, our model can be applied to various base models without further training.
25 |
26 | > 📌 Of course, you can retrain your own model, especially for complex tasks or to achieve better performance.
27 |
28 | - [DreamShaper](https://huggingface.co/Lykon/DreamShaper)
29 |
30 |
31 |
32 |
33 |
34 | - [Anything-v3.0](https://huggingface.co/admruul/anything-v3.0)
35 |
36 |
37 |
38 |
39 |
40 | ## LoRA
41 |
42 | Our model can also be directly combined with other publicly available LoRA weights.
43 |
44 | - [Lisa](https://civitai.com/articles/4584)
45 |
46 |
47 |
48 |
49 |
50 | - [Fischl](https://civitai.com/articles/4584)
51 |
52 |
53 |
54 |
55 |
56 | - [Chinese Style](https://civitai.com/models/12597/moxin)
57 |
58 |
59 |
60 |
61 |
62 | ## Stable Generation
63 |
64 | Sometimes it is difficult to generate good results, and you have to adjust your prompt repeatedly to achieve a satisfactory outcome. This process is challenging because prompts are very abstract.
65 |
66 | Our method can serve as a plug-and-play module for stable generation.
67 |
68 | - Without ControlNeXt (Use original [SD](https://huggingface.co/runwayml/stable-diffusion-v1-5) as base model)
69 |
70 |
71 |
72 |
73 | - With ControlNeXt (Use original [SD](https://huggingface.co/runwayml/stable-diffusion-v1-5) as base model)
74 |
75 |
76 |
77 |
78 | # Inference
79 |
80 | 1. Clone our repository
81 | 2. `cd ControlNeXt-SD1.5`
82 | 3. Download the pretrained weights into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SD1.5). (We recommend `deepfashion_multiview` and `deepfashion_caption`.)
83 | 4. (Optional) Download LoRA weights, such as [Genshin](https://civitai.com/models/362091/sd15all-characters-genshin-impact-124-characters-124), and put them under `lora/`
84 | 5. Run the script
85 |
86 | ```bash
87 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
88 | --pretrained_model_name_or_path="admruul/anything-v3.0" \
89 | --output_dir="examples/deepfashion_multiview" \
90 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
91 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \
92 | --negative_prompt "PBH" "PBH"\
93 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
94 | --lora_path lora/yuanshen/genshin_124.safetensors \
95 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
96 | ```
97 |
98 | > --pretrained_model_name_or_path : the pretrained base model; we have tried [DreamShaper](https://huggingface.co/Lykon/DreamShaper), [Anything-v3.0](https://huggingface.co/admruul/anything-v3.0), and the original [SD1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) \
99 | > --controlnet_model_name_or_path : the path of the ControlNeXt control module (a lightweight module) \
100 | > --lora_path (optional) : a downloaded LoRA weight to combine with \
101 | > --unet_model_name_or_path (optional, less generality, stricter control) : the path of a finetuned subset of UNet parameters
102 |
103 | > 📌 Pose-based generation is a relatively simple task, and in most cases it is enough to load only the control module via `--controlnet_model_name_or_path`. However, sometimes the task is harder, and you then need to load a selected subset of the original UNet parameters to fit it (which can be seen as another kind of LoRA). \
104 | > More adapted parameters mean weaker generality, so make your own tradeoff, or directly train your own model on your own data; training is also fast.
105 |
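For intuition, loading such a UNet parameter subset can look roughly like the sketch below (a simplified assumption, not the actual loading code in `run_controlnext.py`):

```python
import torch
from safetensors.torch import load_file

def load_unet_subset(unet: torch.nn.Module, path: str) -> None:
    """Merge a partial set of finetuned UNet weights into the pretrained UNet."""
    subset = load_file(path)  # e.g. "pretrained/deepfashion_multiview/unet.safetensors"
    missing, _ = unet.load_state_dict(subset, strict=False)
    print(f"overrode {len(subset)} tensors; {len(missing)} pretrained tensors left untouched")
```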
106 | # Model Zoo
107 |
108 | We also provide some additional examples, but these are just for demonstration purposes. The training data is relatively small and of low quality.
109 |
110 | - vidit_depth
111 |
112 |
113 |
114 |
115 | - mask
116 |
117 |
118 |
119 |
120 | # TODO
121 |
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_0.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_1.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/chinese_style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/chinese_style.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/image_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/image_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_bad.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_bad.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_good.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_good.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_caption/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3 | --output_dir="examples/deepfashion_caption" \
4 | --validation_image "examples/deepfashion_caption/condition_1.png" "examples/deepfashion_caption/condition_0.png" \
5 | --validation_prompt "a woman wearing a black shirt and black leather skirt" "levi's women's white graphic t - shirt" \
6 | --controlnet_model_name_or_path pretrained/deepfashion_caption/controlnet.safetensors \
7 | --unet_model_name_or_path pretrained/deepfashion_caption/unet.safetensors
8 |
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_1.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_2.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_fischl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_fischl.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_lisa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_lisa.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/DreamShaper.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/DreamShaper.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashion_multiview/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
2 | --pretrained_model_name_or_path="admruul/anything-v3.0" \
3 | --output_dir="examples/deepfashion_multiview" \
4 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
5 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \
6 | --negative_prompt "PBH" "PBH"\
7 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
8 | --lora_path lora/yuanshen/genshin_124.safetensors \
9 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
10 |
11 |
12 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
13 | --pretrained_model_name_or_path="admruul/anything-v3.0" \
14 | --output_dir="examples/deepfashion_multiview" \
15 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
16 | --validation_prompt "1boy, braid, single_earring, short_sleeves, white_scarf, black_gloves, alternate_costume, standing, black_shirt" "1boy, braid, single_earring, short_sleeves, white_scarf, black_gloves, alternate_costume, standing, black_shirt" \
17 | --negative_prompt "PBH" "PBH"\
18 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
19 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
20 |
21 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
22 | --pretrained_model_name_or_path="admruul/anything-v3.0" \
23 | --output_dir="examples/deepfashion_multiview" \
24 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
25 | --validation_prompt "lisa_\(a_sobriquet_under_shade\)_\(genshin_impact\), lisa_\(genshin_impact\), 1girl, green_headwear, official_alternate_costume, cleavage, twin_braids, hair_flower, vision_\(genshin_impact\), large_breasts, thighlet, puffy_long_sleeves, purple_rose, beret" "lisa_\(a_sobriquet_under_shade\)_\(genshin_impact\), lisa_\(genshin_impact\), 1girl, green_headwear, official_alternate_costume, cleavage, twin_braids, hair_flower, vision_\(genshin_impact\), large_breasts, thighlet, puffy_long_sleeves, purple_rose, beret" \
26 | --negative_prompt "PBH" "PBH"\
27 | --lora_path lora/yuanshen/genshin_124.safetensors \
28 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
29 |
30 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
31 | --pretrained_model_name_or_path="admruul/anything-v3.0" \
32 | --output_dir="examples/deepfashion_multiview" \
33 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
34 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \
35 | --negative_prompt "PBH" "PBH"\
36 | --lora_path lora/yuanshen/genshin_124.safetensors \
37 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
38 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
39 |
40 |
41 | # base generation
42 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
43 | --pretrained_model_name_or_path="Lykon/DreamShaper" \
44 | --output_dir="examples/deepfashion_multiview" \
45 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg"\
46 | --validation_prompt "a woman in white shorts and a tank top" "a woman wearing a black shirt and black leather skirt" \
47 | --negative_prompt "PBH" "PBH" \
48 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
49 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
50 |
51 |
52 | # Combine with LoRA
53 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
54 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
55 | --output_dir="examples/deepfashion_multiview" \
56 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg"\
57 | --negative_prompt "PBH" "PBH" \
58 | --validation_prompt "c1bo, a woman, Armor, weapon, beautiful" "c1bo, a man, fight" \
59 | --lora_path lora/c1bo/cyborg_v_2_SD15.safetensors \
60 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \
61 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors
62 |
63 | # Combine with LoRA, without our control
64 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
65 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
66 | --output_dir="examples/deepfashion_multiview" \
67 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \
68 | --negative_prompt "PBH" "PBH" \
69 | --validation_prompt "c1bo, a woman, Armor, weapon, beautiful" "c1bo, a man, fight" \
70 | --lora_path lora/c1bo/cyborg_v_2_SD15.safetensors
71 |
72 |
73 |
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashoin_mask/condition_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashoin_mask/condition_0.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashoin_mask/eval_img/image_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashoin_mask/eval_img/image_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/deepfashoin_mask/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3 | --output_dir="examples/deepfashoin_mask" \
4 | --validation_image "examples/deepfashoin_mask/condition_0.png" \
5 | --validation_prompt "a woman in white shorts and a tank top" \
6 | --controlnet_model_name_or_path pretrained/deepfashoin_mask/controlnet.safetensors \
7 | --unet_model_name_or_path pretrained/deepfashoin_mask/unet.safetensors
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/vidit_depth/condition_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/vidit_depth/condition_0.png
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/vidit_depth/eval_img/image_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/vidit_depth/eval_img/image_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SD1.5/examples/vidit_depth/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3 | --output_dir="examples/vidit_depth" \
4 | --validation_image "examples/vidit_depth/condition_0.png" \
5 | --validation_prompt "a wooden bridge in the middle of a field" \
6 | --controlnet_model_name_or_path pretrained/vidit_depth/controlnet.safetensors \
7 | --unet_model_name_or_path pretrained/vidit_depth/unet.safetensors
--------------------------------------------------------------------------------
/ControlNeXt-SD3/README.md:
--------------------------------------------------------------------------------
1 | # ControlNeXt-SD3
2 |
3 | We regret to inform you that our Stable Diffusion 3 model is trained with protected, private data and code, and therefore cannot be released. However, we are considering releasing another model based on the DiT architecture in the future.
4 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SDXL
2 |
3 | This is our **training** demo of ControlNeXt based on [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
4 |
5 | Hardware requirement: A single GPU with at least 20GB memory.
6 |
7 | ## Quick Start
8 |
9 | Clone the repository:
10 |
11 | ```bash
12 | git clone https://github.com/dvlab-research/ControlNeXt
13 | cd ControlNeXt/ControlNeXt-SDXL-Training
14 | ```
15 |
16 | Install the required packages:
17 |
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 |
22 | Run the training script:
23 |
24 | ```bash
25 | bash examples/vidit_depth/train.sh
26 | ```
27 |
28 | The output will be saved in `train/example`.
29 |
30 | ## Usage
31 |
32 | We recommend saving and loading only the difference of the UNet's trainable weights, i.e., $\Delta W = W_{finetune} - W_{pretrained}$, rather than the full weights.
33 | This is useful when adapting to various base models, since the weight difference is model-agnostic. For `--controlnet_scale_factor` (the strength of the ControlNeXt output), we recommend 1.0 for dense conditions such as depth and 0.35 for canny.
34 |
35 | ```bash
36 | accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \
37 | --pretrained_vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
38 | --variant fp16 \
39 | --use_safetensors \
40 | --output_dir "train/example" \
41 | --logging_dir "logs" \
42 | --resolution 1024 \
43 | --gradient_checkpointing \
44 | --set_grads_to_none \
45 | --proportion_empty_prompts 0.2 \
46 | --controlnet_scale_factor 1.0 \
47 | --save_weights_increaments \
48 | --mixed_precision fp16 \
49 | --enable_xformers_memory_efficient_attention \
50 | --dataset_name "Nahrawy/VIDIT-Depth-ControlNet" \
51 | --image_column "image" \
52 | --conditioning_image_column "depth_map" \
53 | --caption_column "caption" \
54 | --validation_prompt "a stone tower on a rocky island" \
55 | --validation_image "examples/vidit_depth/condition_0.png"
56 | ```
57 |
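Because an increment is a difference rather than an absolute weight, it can be added onto the UNet of a different SDXL-based model. A minimal sketch of that idea (not the pipeline's actual `load_controlnext_unet_weights` implementation):

```python
import torch

def apply_weight_increments(base_unet, increment_path):
    """Add saved dW tensors onto the UNet of any SDXL-based base model."""
    deltas = torch.load(increment_path, map_location="cpu")
    state = base_unet.state_dict()
    for name, delta in deltas.items():
        state[name] = state[name] + delta.to(dtype=state[name].dtype)
    base_unet.load_state_dict(state)
```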
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/examples/vidit_depth/condition_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL-Training/examples/vidit_depth/condition_0.png
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/examples/vidit_depth/train.sh:
--------------------------------------------------------------------------------
1 | accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \
2 | --pretrained_vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
3 | --variant fp16 \
4 | --use_safetensors \
5 | --output_dir "train/example" \
6 | --logging_dir "logs" \
7 | --resolution 1024 \
8 | --gradient_checkpointing \
9 | --set_grads_to_none \
10 | --proportion_empty_prompts 0.2 \
11 | --controlnet_scale_factor 1.0 \
12 | --mixed_precision fp16 \
13 | --enable_xformers_memory_efficient_attention \
14 | --dataset_name "Nahrawy/VIDIT-Depth-ControlNet" \
15 | --image_column "image" \
16 | --conditioning_image_column "depth_map" \
17 | --caption_column "caption" \
18 | --validation_prompt "a stone tower on a rocky island" \
19 | --validation_image "examples/vidit_depth/condition_0.png"
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | accelerate
4 | opencv-python
5 | pillow
6 | numpy
7 | transformers
8 | diffusers
9 | safetensors
10 | peft
11 | xformers
12 | huggingface-hub
13 | datasets
14 | einops
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | def get_extractor(extractor_name):
7 | if extractor_name is None:
8 | return None
9 | if extractor_name not in EXTRACTORS:
10 | raise ValueError(f"Extractor {extractor_name} is not supported.")
11 | return EXTRACTORS[extractor_name]
12 |
13 |
14 | def canny_extractor(image: Image.Image, threshold1=None, threshold2=None) -> Image.Image:
15 | image = np.array(image)
16 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
17 | v = np.median(gray)
18 |
19 | sigma = 0.33
20 | threshold1 = threshold1 or int(max(0, (1.0 - sigma) * v))
21 | threshold2 = threshold2 or int(min(255, (1.0 + sigma) * v))
22 |
23 | edges = cv2.Canny(gray, threshold1, threshold2)
24 | edges = Image.fromarray(edges).convert("RGB")
25 | return edges
26 |
27 |
28 | def depth_extractor(image: Image.Image):
29 | raise NotImplementedError("Depth extractor is not implemented yet.")
30 |
31 |
32 | def pose_extractor(image: Image.Image):
33 | raise NotImplementedError("Pose extractor is not implemented yet.")
34 |
35 |
36 | EXTRACTORS = {
37 | "canny": canny_extractor,
38 | }
39 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import torch
4 | from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel
5 | from safetensors.torch import load_file
6 | from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
7 | from models.unet import UNet2DConditionModel
8 | from models.controlnet import ControlNetModel
9 | from . import utils
10 |
11 | UNET_CONFIG = {
12 | "act_fn": "silu",
13 | "addition_embed_type": "text_time",
14 | "addition_embed_type_num_heads": 64,
15 | "addition_time_embed_dim": 256,
16 | "attention_head_dim": [
17 | 5,
18 | 10,
19 | 20
20 | ],
21 | "block_out_channels": [
22 | 320,
23 | 640,
24 | 1280
25 | ],
26 | "center_input_sample": False,
27 | "class_embed_type": None,
28 | "class_embeddings_concat": False,
29 | "conv_in_kernel": 3,
30 | "conv_out_kernel": 3,
31 | "cross_attention_dim": 2048,
32 | "cross_attention_norm": None,
33 | "down_block_types": [
34 | "DownBlock2D",
35 | "CrossAttnDownBlock2D",
36 | "CrossAttnDownBlock2D"
37 | ],
38 | "downsample_padding": 1,
39 | "dual_cross_attention": False,
40 | "encoder_hid_dim": None,
41 | "encoder_hid_dim_type": None,
42 | "flip_sin_to_cos": True,
43 | "freq_shift": 0,
44 | "in_channels": 4,
45 | "layers_per_block": 2,
46 | "mid_block_only_cross_attention": None,
47 | "mid_block_scale_factor": 1,
48 | "mid_block_type": "UNetMidBlock2DCrossAttn",
49 | "norm_eps": 1e-05,
50 | "norm_num_groups": 32,
51 | "num_attention_heads": None,
52 | "num_class_embeds": None,
53 | "only_cross_attention": False,
54 | "out_channels": 4,
55 | "projection_class_embeddings_input_dim": 2816,
56 | "resnet_out_scale_factor": 1.0,
57 | "resnet_skip_time_act": False,
58 | "resnet_time_scale_shift": "default",
59 | "sample_size": 128,
60 | "time_cond_proj_dim": None,
61 | "time_embedding_act_fn": None,
62 | "time_embedding_dim": None,
63 | "time_embedding_type": "positional",
64 | "timestep_post_act": None,
65 | "transformer_layers_per_block": [
66 | 1,
67 | 2,
68 | 10
69 | ],
70 | "up_block_types": [
71 | "CrossAttnUpBlock2D",
72 | "CrossAttnUpBlock2D",
73 | "UpBlock2D"
74 | ],
75 | "upcast_attention": None,
76 | "use_linear_projection": True
77 | }
78 |
79 | CONTROLNET_CONFIG = {
80 | 'in_channels': [128, 128],
81 | 'out_channels': [128, 256],
82 | 'groups': [4, 8],
83 | 'time_embed_dim': 256,
84 | 'final_out_channels': 320,
85 | '_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels']
86 | }
87 |
88 |
89 | def get_pipeline(
90 | pretrained_model_name_or_path,
91 | unet_model_name_or_path,
92 | controlnet_model_name_or_path,
93 | vae_model_name_or_path=None,
94 | lora_path=None,
95 | load_weight_increasement=False,
96 | enable_xformers_memory_efficient_attention=False,
97 | revision=None,
98 | variant=None,
99 | hf_cache_dir=None,
100 | use_safetensors=True,
101 | device=None,
102 | ):
103 | pipeline_init_kwargs = {}
104 |
105 | print(f"loading unet from {pretrained_model_name_or_path}")
106 | if os.path.isfile(pretrained_model_name_or_path):
107 | # load unet from local checkpoint
108 | unet_sd = load_file(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".safetensors") else torch.load(pretrained_model_name_or_path)
109 | unet_sd = utils.extract_unet_state_dict(unet_sd)
110 | unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
111 | unet = UNet2DConditionModel.from_config(UNET_CONFIG)
112 | unet.load_state_dict(unet_sd, strict=True)
113 | else:
114 | unet = UNet2DConditionModel.from_pretrained(
115 | pretrained_model_name_or_path,
116 | cache_dir=hf_cache_dir,
117 | variant=variant,
118 | torch_dtype=torch.float16,
119 | use_safetensors=use_safetensors,
120 | subfolder="unet",
121 | )
122 | unet = unet.to(dtype=torch.float16)
123 | pipeline_init_kwargs["unet"] = unet
124 |
125 | if vae_model_name_or_path is not None:
126 | print(f"loading vae from {vae_model_name_or_path}")
127 | vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device)
128 | pipeline_init_kwargs["vae"] = vae
129 |
130 | if controlnet_model_name_or_path is not None:
131 | pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32) # init
132 |
133 | print(f"loading pipeline from {pretrained_model_name_or_path}")
134 | if os.path.isfile(pretrained_model_name_or_path):
135 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
136 | pretrained_model_name_or_path,
137 | use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
138 | local_files_only=True,
139 | cache_dir=hf_cache_dir,
140 | **pipeline_init_kwargs,
141 | )
142 | else:
143 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
144 | pretrained_model_name_or_path,
145 | revision=revision,
146 | variant=variant,
147 | use_safetensors=use_safetensors,
148 | cache_dir=hf_cache_dir,
149 | **pipeline_init_kwargs,
150 | )
151 |
152 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
153 | if unet_model_name_or_path is not None:
154 | print(f"loading controlnext unet from {unet_model_name_or_path}")
155 | pipeline.load_controlnext_unet_weights(
156 | unet_model_name_or_path,
157 | load_weight_increasement=load_weight_increasement,
158 | use_safetensors=True,
159 | torch_dtype=torch.float16,
160 | cache_dir=hf_cache_dir,
161 | )
162 | if controlnet_model_name_or_path is not None:
163 | print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
164 | pipeline.load_controlnext_controlnet_weights(
165 | controlnet_model_name_or_path,
166 | use_safetensors=True,
167 | torch_dtype=torch.float32,
168 | cache_dir=hf_cache_dir,
169 | )
170 | pipeline.set_progress_bar_config()
171 | pipeline = pipeline.to(device, dtype=torch.float16)
172 |
173 | if lora_path is not None:
174 | pipeline.load_lora_weights(lora_path)
175 | if enable_xformers_memory_efficient_attention:
176 | pipeline.enable_xformers_memory_efficient_attention()
177 |
178 | gc.collect()
179 | if str(device) == 'cuda' and torch.cuda.is_available():
180 | torch.cuda.empty_cache()
181 |
182 | return pipeline
183 |
184 |
185 | def get_scheduler(
186 | scheduler_name,
187 | scheduler_config,
188 | ):
189 | if scheduler_name == 'Euler A':
190 | from diffusers.schedulers import EulerAncestralDiscreteScheduler
191 | scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
192 | elif scheduler_name == 'UniPC':
193 | from diffusers.schedulers import UniPCMultistepScheduler
194 | scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
195 | elif scheduler_name == 'Euler':
196 | from diffusers.schedulers import EulerDiscreteScheduler
197 | scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
198 | elif scheduler_name == 'DDIM':
199 | from diffusers.schedulers import DDIMScheduler
200 | scheduler = DDIMScheduler.from_config(scheduler_config)
201 | elif scheduler_name == 'DDPM':
202 | from diffusers.schedulers import DDPMScheduler
203 | scheduler = DDPMScheduler.from_config(scheduler_config)
204 | else:
205 | raise ValueError(f"Unknown scheduler: {scheduler_name}")
206 | return scheduler
207 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL-Training/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple, Union, Optional
3 |
4 |
5 | def make_unet_conversion_map():
6 | unet_conversion_map_layer = []
7 |
8 | for i in range(3): # num_blocks is 3 in sdxl
9 | # loop over downblocks/upblocks
10 | for j in range(2):
11 | # loop over resnets/attentions for downblocks
12 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
13 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
14 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
15 |
16 | if i < 3:
17 | # no attention layers in down_blocks.3
18 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
19 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
20 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
21 |
22 | for j in range(3):
23 | # loop over resnets/attentions for upblocks
24 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
25 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
26 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
27 |
28 | # if i > 0: commentout for sdxl
29 | # no attention layers in up_blocks.0
30 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
31 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
32 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
33 |
34 | if i < 3:
35 | # no downsample in down_blocks.3
36 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
37 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
38 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
39 |
40 | # no upsample in up_blocks.3
41 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
42 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
43 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
44 |
45 | hf_mid_atn_prefix = "mid_block.attentions.0."
46 | sd_mid_atn_prefix = "middle_block.1."
47 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
48 |
49 | for j in range(2):
50 | hf_mid_res_prefix = f"mid_block.resnets.{j}."
51 | sd_mid_res_prefix = f"middle_block.{2*j}."
52 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
53 |
54 | unet_conversion_map_resnet = [
55 | # (stable-diffusion, HF Diffusers)
56 | ("in_layers.0.", "norm1."),
57 | ("in_layers.2.", "conv1."),
58 | ("out_layers.0.", "norm2."),
59 | ("out_layers.3.", "conv2."),
60 | ("emb_layers.1.", "time_emb_proj."),
61 | ("skip_connection.", "conv_shortcut."),
62 | ]
63 |
64 | unet_conversion_map = []
65 | for sd, hf in unet_conversion_map_layer:
66 | if "resnets" in hf:
67 | for sd_res, hf_res in unet_conversion_map_resnet:
68 | unet_conversion_map.append((sd + sd_res, hf + hf_res))
69 | else:
70 | unet_conversion_map.append((sd, hf))
71 |
72 | for j in range(2):
73 | hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
74 | sd_time_embed_prefix = f"time_embed.{j*2}."
75 | unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
76 |
77 | for j in range(2):
78 | hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
79 | sd_label_embed_prefix = f"label_emb.0.{j*2}."
80 | unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
81 |
82 | unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
83 | unet_conversion_map.append(("out.0.", "conv_norm_out."))
84 | unet_conversion_map.append(("out.2.", "conv_out."))
85 |
86 | return unet_conversion_map
87 |
88 |
89 | def convert_unet_state_dict(src_sd, conversion_map):
90 | converted_sd = {}
91 | for src_key, value in src_sd.items():
92 | src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
93 | while len(src_key_fragments) > 0:
94 | src_key_prefix = ".".join(src_key_fragments) + "."
95 | if src_key_prefix in conversion_map:
96 | converted_prefix = conversion_map[src_key_prefix]
97 | converted_key = converted_prefix + src_key[len(src_key_prefix):]
98 | converted_sd[converted_key] = value
99 | break
100 | src_key_fragments.pop(-1)
101 | assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
102 |
103 | return converted_sd
104 |
105 |
106 | def convert_sdxl_unet_state_dict_to_diffusers(sd):
107 | unet_conversion_map = make_unet_conversion_map()
108 |
109 | conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
110 | return convert_unet_state_dict(sd, conversion_dict)
111 |
112 |
113 | def extract_unet_state_dict(state_dict):
114 | unet_sd = {}
115 | UNET_KEY_PREFIX = "model.diffusion_model."
116 | for k, v in state_dict.items():
117 | if k.startswith(UNET_KEY_PREFIX):
118 | unet_sd[k[len(UNET_KEY_PREFIX):]] = v
119 | return unet_sd
120 |
121 |
122 | def log_model_info(model, name):
123 | sd = model.state_dict() if hasattr(model, "state_dict") else model
124 | print(
125 | f"{name}:",
126 | f" number of parameters: {sum(p.numel() for p in sd.values())}",
127 | f" dtype: {sd[next(iter(sd))].dtype}",
128 | sep='\n'
129 | )
130 |
131 |
132 | def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]:
133 | r"""
134 |     w*h = reso_w*reso_h
135 |     w/h = img_w/img_h
136 |     => w = img_ar*h
137 |     => img_ar*h^2 = reso_w*reso_h
138 |     => h = sqrt(reso_w*reso_h / img_ar)
139 | """
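    # Illustrative example (computed, not from the repo):
    # around_reso(1000, 1500, reso=1024, divisible=8) -> (832, 1248), i.e. the 2:3 aspect
    # ratio is kept while the area is brought close to 1024*1024.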
140 | reso = reso if isinstance(reso, tuple) else (reso, reso)
141 | divisible = divisible or 1
142 | if img_w * img_h <= reso[0] * reso[1] and (not max_width or img_w <= max_width) and (not max_height or img_h <= max_height) and img_w % divisible == 0 and img_h % divisible == 0:
143 | return (img_w, img_h)
144 | img_ar = img_w / img_h
145 | around_h = math.sqrt(reso[0]*reso[1] / img_ar)
146 | around_w = img_ar * around_h // divisible * divisible
147 | if max_width and around_w > max_width:
148 | around_h = around_h * max_width // around_w
149 | around_w = max_width
150 | elif max_height and around_h > max_height:
151 | around_w = around_w * max_height // around_h
152 | around_h = max_height
153 | around_h = min(around_h, max_height) if max_height else around_h
154 | around_w = min(around_w, max_width) if max_width else around_w
155 | around_h = int(around_h // divisible * divisible)
156 | around_w = int(around_w // divisible * divisible)
157 | return (around_w, around_h)
158 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SDXL
2 |
3 | This is our implementation of ControlNeXt based on [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
4 |
5 | > Please refer to [Examples](#examples) for further intuitive details.\
6 | > Please refer to [Inference](#inference) for more details regarding installation and inference.
7 |
8 | Our method demonstrates the advantages listed below:
9 |
10 | - **Few trainable parameters**: only **5~200M** trainable parameters are required.
11 | - **Fast training speed**: alleviates the "sudden convergence" problem and converges quickly.
12 | - **Efficient**: no additional heavy branch is needed; only a lightweight module is required.
13 | - **Compatibility**: serves as a **plug-and-play** lightweight module and can be combined with other LoRA weights.
14 |
15 | # Examples
16 |
17 | The demo examples are generated using ControlNeXt trained on
18 |
19 | - (i) our vidit_depth dataset, using [Stable Diffusion XL 1.0 Base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) as the base model.
20 | - (ii) our anime_canny dataset, using [Neta Art XL 2.0](https://civitai.com/models/410737/neta-art-xl) as the base model.
21 |
22 | Our method demonstrates excellent compatibility and can be applied to most other models based on the SDXL 1.0 architecture, as well as to LoRA weights. You can also retrain your own model for better performance.
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 | ## BaseModel
32 |
33 | Our model can be applied to various base models as a plug-and-play module, without the need for further training.
34 |
35 | > 📌 Of course, you can retrain your own model, especially for complex tasks and to achieve better performance.
36 |
37 | - [Stable Diffusion XL 1.0 Base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
38 |
39 |
40 |
41 |
42 |
43 | - [AAM XL](https://huggingface.co/Lykon/AAM_XL_AnimeMix)
44 |
45 |
46 |
47 |
48 |
49 | - [Neta XL V2](https://civitai.com/models/410737/neta-art-xl)
50 |
51 |
52 |
53 |
54 |
55 | ## LoRA
56 |
57 | Our model can also be directly combined with other publicly available LoRA weights.
58 |
59 | - [Glass Sculptures](https://civitai.com/models/11203/glass-sculptures?modelVersionId=177888)
60 |
61 |
62 |
63 |
64 |
65 | # Inference
66 |
67 | ## Quick Start
68 |
69 | Clone the repository:
70 |
71 | ```bash
72 | git clone https://github.com/dvlab-research/ControlNeXt
73 | cd ControlNeXt/ControlNeXt-SDXL
74 | ```
75 |
76 | Install the required packages:
77 |
78 | ```bash
79 | pip install -r requirements.txt
80 | ```
81 |
82 | (Optional) Download a LoRA weight, such as [Amiya (Arknights) Fresh Art Style](https://civitai.com/models/231598/amiya-arknights-fresh-art-style-xl-trained-with-6k-images), and place it under `lora/`.
83 |
84 | Run the example:
85 |
86 | ```bash
87 | bash examples/anime_canny/run.sh
88 | ```
89 |
90 | ## Usage
91 |
92 | ### Canny Condition
93 |
94 | ```bash
95 | # examples/anime_canny/run.sh
96 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \
97 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
98 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
99 | --controlnet_scale 1.0 \ # controlnet scale factor used to adjust the strength of the control condition
100 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
101 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \
102 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \
103 | --validation_image "examples/anime_canny/condition_0.png" \ # input canny image
104 | --output_dir "examples/anime_canny" \
105 | --load_weight_increasement # load the saved weight difference (ΔW) of the UNet rather than the full weights
106 | ```
107 |
108 | We use a `controlnet_scale` factor to adjust the strength of the control condition.
109 |
110 | We recommend saving & loading only the weight difference of the UNet's trainable parameters, i.e., $\Delta W = W_{finetune} - W_{pretrained}$, rather than the full weights.
111 | This is useful when adapting to various base models, since the weight difference is model-agnostic.
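As a minimal sketch (not the repository's actual loader; the file names are illustrative), applying such a saved weight difference to a base UNet state dict amounts to:

```python
# Minimal sketch: restore W_finetune = W_pretrained + delta_W from a saved weight difference.
# The file names below are placeholders, not files shipped with this repository.
from safetensors.torch import load_file

unet_sd = load_file("unet_pretrained.safetensors")            # W_pretrained
delta_sd = load_file("unet_weight_increasement.safetensors")  # delta_W, trained subset only

for key, delta in delta_sd.items():
    # only the UNet's trainable parameters appear in the delta checkpoint
    unet_sd[key] = unet_sd[key] + delta.to(unet_sd[key].dtype)

# unet.load_state_dict(unet_sd, strict=False) would then give the adapted UNet
```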
112 |
113 | ### Depth Condition
114 |
115 | ```bash
116 | # examples/vidit_depth/run.sh
117 | python run_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \
118 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \
119 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \
120 | --controlnet_scale 1.0 \
121 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
122 | --validation_prompt "a diamond tower in the middle of a lava lake" \
123 | --validation_image "examples/vidit_depth/condition_0.png" \ # input depth image
124 | --output_dir "examples/vidit_depth" \
125 | --width 1024 \
126 | --height 1024 \
127 | --load_weight_increasement \
128 | --variant fp16
129 | ```
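If you prefer building the pipeline from Python instead of the CLI, `utils/tools.py` exposes a `get_pipeline` helper. A sketch mirroring the depth example above (generation arguments are left out, since they depend on the pipeline's call signature):

```python
# Sketch: build the ControlNeXt-SDXL pipeline programmatically via utils/tools.py.
from utils.tools import get_pipeline

pipeline = get_pipeline(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    unet_model_name_or_path="Eugeoter/controlnext-sdxl-vidit-depth",
    controlnet_model_name_or_path="Eugeoter/controlnext-sdxl-vidit-depth",
    vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix",
    load_weight_increasement=True,
    variant="fp16",
    device="cuda",
)
# Prompt, condition image, width/height, etc. are then passed to pipeline(...);
# see pipeline/pipeline_controlnext.py for the exact call arguments.
```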
130 |
131 | ## Run with Image Processor
132 |
133 | We also provide a simple image processor that automatically converts an input image into the control condition, such as a canny edge map.
134 |
135 | ```bash
136 | # examples/anime_canny/run_with_pp.sh
137 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \
138 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
139 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
140 | --controlnet_scale 1.0 \
141 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
142 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \
143 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \
144 | --validation_image "examples/anime_canny/image_0.png" \ # input image (not canny)
145 | --validation_image_processor "canny" \ # preprocess `validation_image` to canny condition
146 | --output_dir "examples/anime_canny" \
147 | --load_weight_increasement
148 | ```
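Under the hood, the `canny` processor is the median-based auto-threshold extractor defined in `utils/preprocess.py`. A minimal standalone usage sketch (assuming you run it from the `ControlNeXt-SDXL` directory; the output path is illustrative):

```python
# Sketch: convert an input image into a canny condition with the bundled extractor.
from PIL import Image
from utils.preprocess import get_extractor

extractor = get_extractor("canny")                       # thresholds derived from the image median
image = Image.open("examples/anime_canny/image_0.jpg")
condition = extractor(image)                             # RGB PIL image of canny edges
condition.save("examples/anime_canny/condition_from_pp.jpg")  # illustrative output path
```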
149 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/condition_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/condition_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/eval_img/AAM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/eval_img/AAM.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/eval_img/NetaXLV2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/eval_img/NetaXLV2.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/image_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/image_0.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/run.sh:
--------------------------------------------------------------------------------
1 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \
2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
4 | --controlnet_scale 1.0 \
5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
6 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \
7 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \
8 | --validation_image "examples/anime_canny/condition_0.jpg" \
9 | --output_dir "examples/anime_canny" \
10 | --load_weight_increasement \
11 | --use_safetensors \
12 | --variant fp16
13 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/anime_canny/run_with_pp.sh:
--------------------------------------------------------------------------------
1 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \
2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \
4 | --controlnet_scale 1.0 \
5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
6 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \
7 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \
8 | --validation_image "examples/anime_canny/image_0.jpg" \
9 | --validation_image_processor "canny" \
10 | --output_dir "examples/anime_canny" \
11 | --load_weight_increasement \
12 | --use_safetensors \
13 | --variant fp16
14 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/demo/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo1.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/demo/demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo2.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/demo/demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo3.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/demo/demo4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo4.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/demo/demo5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo5.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/vidit_depth/condition_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/condition_0.png
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL_GlassSculpturesLora.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL_GlassSculpturesLora.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/examples/vidit_depth/run.sh:
--------------------------------------------------------------------------------
1 | python run_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \
2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \
3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \
4 | --controlnet_scale 1.0 \
5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \
6 | --validation_prompt "a diamond tower in the middle of a lava lake" \
7 | --validation_image "examples/vidit_depth/condition_0.png" \
8 | --output_dir "examples/vidit_depth" \
9 | --width 1024 \
10 | --height 1024 \
11 | --load_weight_increasement \
12 | --variant fp16
13 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | opencv-python
3 | pillow
4 | numpy
5 | transformers
6 | diffusers
7 | safetensors
8 | peft
9 | einops
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | def get_extractor(extractor_name):
7 | if extractor_name is None:
8 | return None
9 | if extractor_name not in EXTRACTORS:
10 | raise ValueError(f"Extractor {extractor_name} is not supported.")
11 | return EXTRACTORS[extractor_name]
12 |
13 |
14 | def canny_extractor(image: Image.Image, threshold1=None, threshold2=None) -> Image.Image:
15 | image = np.array(image)
16 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
17 | v = np.median(gray)
18 |
19 | sigma = 0.33
20 | threshold1 = threshold1 or int(max(0, (1.0 - sigma) * v))
21 | threshold2 = threshold2 or int(min(255, (1.0 + sigma) * v))
22 |
23 | edges = cv2.Canny(gray, threshold1, threshold2)
24 | edges = Image.fromarray(edges).convert("RGB")
25 | return edges
26 |
27 |
28 | def depth_extractor(image: Image.Image):
29 | raise NotImplementedError("Depth extractor is not implemented yet.")
30 |
31 |
32 | def pose_extractor(image: Image.Image):
33 | raise NotImplementedError("Pose extractor is not implemented yet.")
34 |
35 |
36 | EXTRACTORS = {
37 | "canny": canny_extractor,
38 | }
39 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import torch
4 | from diffusers import UniPCMultistepScheduler, AutoencoderKL  # ControlNetModel comes from models.controlnet below
5 | from safetensors.torch import load_file
6 | from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline
7 | from models.unet import UNet2DConditionModel
8 | from models.controlnet import ControlNetModel
9 | from . import utils
10 |
11 | UNET_CONFIG = {
12 | "act_fn": "silu",
13 | "addition_embed_type": "text_time",
14 | "addition_embed_type_num_heads": 64,
15 | "addition_time_embed_dim": 256,
16 | "attention_head_dim": [
17 | 5,
18 | 10,
19 | 20
20 | ],
21 | "block_out_channels": [
22 | 320,
23 | 640,
24 | 1280
25 | ],
26 | "center_input_sample": False,
27 | "class_embed_type": None,
28 | "class_embeddings_concat": False,
29 | "conv_in_kernel": 3,
30 | "conv_out_kernel": 3,
31 | "cross_attention_dim": 2048,
32 | "cross_attention_norm": None,
33 | "down_block_types": [
34 | "DownBlock2D",
35 | "CrossAttnDownBlock2D",
36 | "CrossAttnDownBlock2D"
37 | ],
38 | "downsample_padding": 1,
39 | "dual_cross_attention": False,
40 | "encoder_hid_dim": None,
41 | "encoder_hid_dim_type": None,
42 | "flip_sin_to_cos": True,
43 | "freq_shift": 0,
44 | "in_channels": 4,
45 | "layers_per_block": 2,
46 | "mid_block_only_cross_attention": None,
47 | "mid_block_scale_factor": 1,
48 | "mid_block_type": "UNetMidBlock2DCrossAttn",
49 | "norm_eps": 1e-05,
50 | "norm_num_groups": 32,
51 | "num_attention_heads": None,
52 | "num_class_embeds": None,
53 | "only_cross_attention": False,
54 | "out_channels": 4,
55 | "projection_class_embeddings_input_dim": 2816,
56 | "resnet_out_scale_factor": 1.0,
57 | "resnet_skip_time_act": False,
58 | "resnet_time_scale_shift": "default",
59 | "sample_size": 128,
60 | "time_cond_proj_dim": None,
61 | "time_embedding_act_fn": None,
62 | "time_embedding_dim": None,
63 | "time_embedding_type": "positional",
64 | "timestep_post_act": None,
65 | "transformer_layers_per_block": [
66 | 1,
67 | 2,
68 | 10
69 | ],
70 | "up_block_types": [
71 | "CrossAttnUpBlock2D",
72 | "CrossAttnUpBlock2D",
73 | "UpBlock2D"
74 | ],
75 | "upcast_attention": None,
76 | "use_linear_projection": True
77 | }
78 |
79 | CONTROLNET_CONFIG = {
80 | 'in_channels': [128, 128],
81 | 'out_channels': [128, 256],
82 | 'groups': [4, 8],
83 | 'time_embed_dim': 256,
84 | 'final_out_channels': 320,
85 | '_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels']
86 | }
87 |
88 |
89 | def get_pipeline(
90 | pretrained_model_name_or_path,
91 | unet_model_name_or_path,
92 | controlnet_model_name_or_path,
93 | vae_model_name_or_path=None,
94 | lora_path=None,
95 | load_weight_increasement=False,
96 | enable_xformers_memory_efficient_attention=False,
97 | revision=None,
98 | variant=None,
99 | hf_cache_dir=None,
100 | use_safetensors=True,
101 | device=None,
102 | ):
103 | pipeline_init_kwargs = {}
104 |
105 | print(f"loading unet from {pretrained_model_name_or_path}")
106 | if os.path.isfile(pretrained_model_name_or_path):
107 | # load unet from local checkpoint
108 | unet_sd = load_file(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".safetensors") else torch.load(pretrained_model_name_or_path)
109 | unet_sd = utils.extract_unet_state_dict(unet_sd)
110 | unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
111 | unet = UNet2DConditionModel.from_config(UNET_CONFIG)
112 | unet.load_state_dict(unet_sd, strict=True)
113 | else:
114 | unet = UNet2DConditionModel.from_pretrained(
115 | pretrained_model_name_or_path,
116 | cache_dir=hf_cache_dir,
117 | variant=variant,
118 | torch_dtype=torch.float16,
119 | use_safetensors=use_safetensors,
120 | subfolder="unet",
121 | )
122 | unet = unet.to(dtype=torch.float16)
123 | pipeline_init_kwargs["unet"] = unet
124 |
125 | if vae_model_name_or_path is not None:
126 | print(f"loading vae from {vae_model_name_or_path}")
127 | vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device)
128 | pipeline_init_kwargs["vae"] = vae
129 |
130 | if controlnet_model_name_or_path is not None:
131 | pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32) # init
132 |
133 | print(f"loading pipeline from {pretrained_model_name_or_path}")
134 | if os.path.isfile(pretrained_model_name_or_path):
135 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file(
136 | pretrained_model_name_or_path,
137 | use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"),
138 | local_files_only=True,
139 | cache_dir=hf_cache_dir,
140 | **pipeline_init_kwargs,
141 | )
142 | else:
143 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained(
144 | pretrained_model_name_or_path,
145 | revision=revision,
146 | variant=variant,
147 | use_safetensors=use_safetensors,
148 | cache_dir=hf_cache_dir,
149 | **pipeline_init_kwargs,
150 | )
151 |
152 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
153 | if unet_model_name_or_path is not None:
154 | print(f"loading controlnext unet from {unet_model_name_or_path}")
155 | pipeline.load_controlnext_unet_weights(
156 | unet_model_name_or_path,
157 | load_weight_increasement=load_weight_increasement,
158 | use_safetensors=True,
159 | torch_dtype=torch.float16,
160 | cache_dir=hf_cache_dir,
161 | )
162 | if controlnet_model_name_or_path is not None:
163 | print(f"loading controlnext controlnet from {controlnet_model_name_or_path}")
164 | pipeline.load_controlnext_controlnet_weights(
165 | controlnet_model_name_or_path,
166 | use_safetensors=True,
167 | torch_dtype=torch.float32,
168 | cache_dir=hf_cache_dir,
169 | )
170 | pipeline.set_progress_bar_config()
171 | pipeline = pipeline.to(device, dtype=torch.float16)
172 |
173 | if lora_path is not None:
174 | pipeline.load_lora_weights(lora_path)
175 | if enable_xformers_memory_efficient_attention:
176 | pipeline.enable_xformers_memory_efficient_attention()
177 |
178 | gc.collect()
179 | if str(device) == 'cuda' and torch.cuda.is_available():
180 | torch.cuda.empty_cache()
181 |
182 | return pipeline
183 |
184 |
185 | def get_scheduler(
186 | scheduler_name,
187 | scheduler_config,
188 | ):
189 | if scheduler_name == 'Euler A':
190 | from diffusers.schedulers import EulerAncestralDiscreteScheduler
191 | scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
192 | elif scheduler_name == 'UniPC':
193 | from diffusers.schedulers import UniPCMultistepScheduler
194 | scheduler = UniPCMultistepScheduler.from_config(scheduler_config)
195 | elif scheduler_name == 'Euler':
196 | from diffusers.schedulers import EulerDiscreteScheduler
197 | scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
198 | elif scheduler_name == 'DDIM':
199 | from diffusers.schedulers import DDIMScheduler
200 | scheduler = DDIMScheduler.from_config(scheduler_config)
201 | elif scheduler_name == 'DDPM':
202 | from diffusers.schedulers import DDPMScheduler
203 | scheduler = DDPMScheduler.from_config(scheduler_config)
204 | else:
205 | raise ValueError(f"Unknown scheduler: {scheduler_name}")
206 | return scheduler
207 |
--------------------------------------------------------------------------------
/ControlNeXt-SDXL/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple, Union, Optional
3 |
4 |
5 | def make_unet_conversion_map():
6 | unet_conversion_map_layer = []
7 |
8 | for i in range(3): # num_blocks is 3 in sdxl
9 | # loop over downblocks/upblocks
10 | for j in range(2):
11 | # loop over resnets/attentions for downblocks
12 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
13 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
14 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
15 |
16 | if i < 3:
17 | # no attention layers in down_blocks.3
18 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
19 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
20 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
21 |
22 | for j in range(3):
23 | # loop over resnets/attentions for upblocks
24 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
25 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
26 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
27 |
28 | # if i > 0: commentout for sdxl
29 | # no attention layers in up_blocks.0
30 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
31 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
32 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
33 |
34 | if i < 3:
35 | # no downsample in down_blocks.3
36 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
37 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
38 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
39 |
40 | # no upsample in up_blocks.3
41 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
42 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
43 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
44 |
45 | hf_mid_atn_prefix = "mid_block.attentions.0."
46 | sd_mid_atn_prefix = "middle_block.1."
47 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
48 |
49 | for j in range(2):
50 | hf_mid_res_prefix = f"mid_block.resnets.{j}."
51 | sd_mid_res_prefix = f"middle_block.{2*j}."
52 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
53 |
54 | unet_conversion_map_resnet = [
55 | # (stable-diffusion, HF Diffusers)
56 | ("in_layers.0.", "norm1."),
57 | ("in_layers.2.", "conv1."),
58 | ("out_layers.0.", "norm2."),
59 | ("out_layers.3.", "conv2."),
60 | ("emb_layers.1.", "time_emb_proj."),
61 | ("skip_connection.", "conv_shortcut."),
62 | ]
63 |
64 | unet_conversion_map = []
65 | for sd, hf in unet_conversion_map_layer:
66 | if "resnets" in hf:
67 | for sd_res, hf_res in unet_conversion_map_resnet:
68 | unet_conversion_map.append((sd + sd_res, hf + hf_res))
69 | else:
70 | unet_conversion_map.append((sd, hf))
71 |
72 | for j in range(2):
73 | hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
74 | sd_time_embed_prefix = f"time_embed.{j*2}."
75 | unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
76 |
77 | for j in range(2):
78 | hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
79 | sd_label_embed_prefix = f"label_emb.0.{j*2}."
80 | unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
81 |
82 | unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
83 | unet_conversion_map.append(("out.0.", "conv_norm_out."))
84 | unet_conversion_map.append(("out.2.", "conv_out."))
85 |
86 | return unet_conversion_map
87 |
88 |
89 | def convert_unet_state_dict(src_sd, conversion_map):
90 | converted_sd = {}
91 | for src_key, value in src_sd.items():
92 | src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
93 | while len(src_key_fragments) > 0:
94 | src_key_prefix = ".".join(src_key_fragments) + "."
95 | if src_key_prefix in conversion_map:
96 | converted_prefix = conversion_map[src_key_prefix]
97 | converted_key = converted_prefix + src_key[len(src_key_prefix):]
98 | converted_sd[converted_key] = value
99 | break
100 | src_key_fragments.pop(-1)
101 | assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
102 |
103 | return converted_sd
104 |
105 |
106 | def convert_sdxl_unet_state_dict_to_diffusers(sd):
107 | unet_conversion_map = make_unet_conversion_map()
108 |
109 | conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
110 | return convert_unet_state_dict(sd, conversion_dict)
111 |
112 |
113 | def extract_unet_state_dict(state_dict):
114 | unet_sd = {}
115 | UNET_KEY_PREFIX = "model.diffusion_model."
116 | for k, v in state_dict.items():
117 | if k.startswith(UNET_KEY_PREFIX):
118 | unet_sd[k[len(UNET_KEY_PREFIX):]] = v
119 | return unet_sd
120 |
121 |
122 | def log_model_info(model, name):
123 | sd = model.state_dict() if hasattr(model, "state_dict") else model
124 | print(
125 | f"{name}:",
126 | f" number of parameters: {sum(p.numel() for p in sd.values())}",
127 | f" dtype: {sd[next(iter(sd))].dtype}",
128 | sep='\n'
129 | )
130 |
131 |
132 | def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]:
133 | r"""
134 |     w*h = reso_w*reso_h
135 |     w/h = img_w/img_h
136 |     => w = img_ar*h
137 |     => img_ar*h^2 = reso_w*reso_h
138 |     => h = sqrt(reso_w*reso_h / img_ar)
139 | """
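    # Illustrative example (computed, not from the repo):
    # around_reso(1000, 1500, reso=1024, divisible=8) -> (832, 1248), i.e. the 2:3 aspect
    # ratio is kept while the area is brought close to 1024*1024.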
140 | reso = reso if isinstance(reso, tuple) else (reso, reso)
141 | divisible = divisible or 1
142 | if img_w * img_h <= reso[0] * reso[1] and (not max_width or img_w <= max_width) and (not max_height or img_h <= max_height) and img_w % divisible == 0 and img_h % divisible == 0:
143 | return (img_w, img_h)
144 | img_ar = img_w / img_h
145 | around_h = math.sqrt(reso[0]*reso[1] / img_ar)
146 | around_w = img_ar * around_h // divisible * divisible
147 | if max_width and around_w > max_width:
148 | around_h = around_h * max_width // around_w
149 | around_w = max_width
150 | elif max_height and around_h > max_height:
151 | around_w = around_w * max_height // around_h
152 | around_h = max_height
153 | around_h = min(around_h, max_height) if max_height else around_h
154 | around_w = min(around_w, max_width) if max_width else around_w
155 | around_h = int(around_h // divisible * divisible)
156 | around_w = int(around_w // divisible * divisible)
157 | return (around_w, around_h)
158 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SVD-v2-Training
2 |
3 | # Important
4 |
5 | I found that sometimes, when I change the version of dependencies, the training may not converge at all. I haven't identified the reason yet, but I've listed all our dependencies in the [requirements.txt](./requirements.txt) file. It's a bit detailed, but you can focus on the key dependencies like `torch`, `deepspeed`, `diffusers`, `accelerate`... When issues arise, checking these first may help. (We use: `diffusers==0.25.0`)
6 |
7 | If you find differences between the training and inference scripts, such as the `import` paths, please follow the training script!
8 |
9 | ## Main
10 |
11 | Due to privacy concerns, we are unable to release certain resources, such as the training data and the SD3-based model. However, we are committed to sharing as much as possible. If you find this repository helpful, please consider giving us a star or citing our work!
12 |
13 | The training scripts are intended for users with a basic understanding of `Python` and `Diffusers`. Therefore, we will not provide every detail. If you have any questions, please refer to the code first. Thank you! If you encounter any bugs, please contact us and let us know.
14 |
15 | ## Experiences
16 |
17 | We share more training experiences in this [issue](https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333) and [here](../experiences.md).
18 | We spent a lot of time finding them, and we now share them with all of you. We hope they help!
19 |
20 |
21 |
22 | ## Training script
23 |
24 | ```bash
25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file ./deepspeed.yaml train_svd.py \
26 | --pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
27 | --output_dir=$PATH_TO_THE_SAVE_DIR \
28 | --dataset_type="ubc" \
29 | --meta_info_path=$PATH_TO_THE_META_INFO_FILE_FOR_DATASET \
30 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_VALIDATION \
31 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_VALIDATION \
32 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_VALIDATION \
33 | --width=576 \
34 | --height=1024 \
35 | --lr_warmup_steps 500 \
36 | --sample_n_frames 21 \
37 | --interval_frame 3 \
38 | --learning_rate=1e-5 \
39 | --per_gpu_batch_size=1 \
40 | --num_train_epochs=6000 \
41 | --mixed_precision="bf16" \
42 | --gradient_accumulation_steps=1 \
43 | --checkpointing_steps=2000 \
44 | --validation_steps=500 \
45 | --gradient_checkpointing \
46 | --checkpoints_total_limit 4
47 |
48 | # For Resume
49 | --controlnet_model_name_or_path $PATH_TO_THE_CONTROLNEXT_WEIGHT
50 | --unet_model_name_or_path $PATH_TO_THE_UNET_WEIGHT
51 | ```
52 |
53 | We set `--num_train_epochs=6000` so that training does not stop on its own; you can stop the process at any point once you believe the results are satisfactory.
54 |
55 | ## Training validation
56 |
57 | Please organize the validation data as follows:
58 | ```
59 | ├───ControlNeXt-SVD-v2-Training
60 | └─── ...
61 | ├───validation
62 | | ├───ground_truth
63 | | | ├───0.png
64 | | | ├─── ...
65 | | | └───13.png
66 | | |
67 | | └───pose
68 | | | ├───0.png
69 | | | ├─── ...
70 | | | └───13.png
71 | | |
72 | | └───reference_image.png
73 | |
74 | └─── ...
75 | ```
76 |
77 | Then point the corresponding arguments to these paths:
78 | ```bash
79 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_VALIDATION \
80 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_VALIDATION \
81 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_VALIDATION \
82 | ```
83 |
84 | ## DeepSpeed
85 |
86 | When using `DeepSpeed` for training, you’ll notice the use of a `DeepSpeedWrapperModel` in the [training script](train_svd.py#L837). This wrapper is necessary because DeepSpeed parallelizes only a single model, so we use it to encapsulate the different modules, including ControlNet and UNet (see the sketch below).
87 |
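As an illustration only (hypothetical names; the actual `DeepSpeedWrapperModel` lives in `train_svd.py`), bundling the modules into a single `nn.Module` looks roughly like this:

```python
# Illustrative sketch, not the repository's DeepSpeedWrapperModel.
# Both the UNet and the ControlNeXt module are registered on one nn.Module,
# so DeepSpeed sees a single model whose parameters it can shard and train.
import torch.nn as nn

class WrapperModel(nn.Module):  # hypothetical name
    def __init__(self, unet, controlnext):
        super().__init__()
        self.unet = unet                # spatio-temporal UNet
        self.controlnext = controlnext  # lightweight ControlNeXt module

    def forward(self, noisy_latents, timesteps, controlnet_cond, **unet_kwargs):
        # ControlNeXtSDVModel.forward returns {'output': ..., 'scale': ...}
        controls = self.controlnext(controlnet_cond, timesteps)
        # How the controls are passed into the UNet is assumed here;
        # see train_svd.py for the actual call.
        return self.unet(noisy_latents, timesteps, conditional_controls=controls, **unet_kwargs)
```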
88 | To perform inference with your trained weights, follow these steps:
89 |
90 | 1. Convert the generated weights to a `.bin` file using the `zero_to_fp32.py` script that DeepSpeed places in the checkpoint directory. This will create a file named `pytorch_model.bin`.
91 | 2. Use the script [unwrap_deepspeed.py](utils/unwrap_deepspeed.py) to separate the modules into distinct state dicts for ControlNet and UNet.
92 | 3. Provide the paths to these weights in your inference script.
93 |
94 | ## Meta info
95 |
96 | Please construct the training dataset and provide a list of the data entries in a `.json` file. We give an example in `meta_info_example/meta_info.json` (the data list) and `meta_info_example/meta_info/1.json` (detailed meta information for each video, recording the detected positions and scores):
97 |
98 | `meta_info.json`
99 | ```json
100 | [
101 | {
102 | "video_path": "PATH_TO_THE_SOURCE_VIDEO",
103 | "guide_path": "PATH_TO_THE_CORRESPONDING_POSE_VIDEO",
104 | "meta_info": "PATH_TO_THE_JSON_FILE_RECORD_THE_DETAILED_DETECTION_RESULTS(we give an example in meta_info/1.json)"
105 | }
106 | ...
107 | ]
108 | ```
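A quick sanity check of the list before training might look like this (illustrative, not part of the repository):

```python
# Verify that every path referenced in meta_info.json exists on disk.
import json, os

with open("meta_info_example/meta_info.json") as f:
    entries = json.load(f)

for entry in entries:
    for key in ("video_path", "guide_path", "meta_info"):
        if not os.path.exists(entry[key]):
            print(f"missing {key}: {entry[key]}")
```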
109 |
110 | ## GPU memory
111 |
112 | Training requires substantial memory, as we use a high resolution and long frame sequences to achieve optimal performance. However, you can apply certain techniques to reduce memory consumption, although they come with a trade-off in performance.
113 |
114 | > 1. Adopt bf16 and fp16 (we have already implemented this).
115 | > 2. Use DeepSpeed and distributed training across multiple machines.
116 | > 3. Reduce the resolution via `--width` and `--height`, e.g. from `576*1024` to `512*768`.
117 | > 4. Reduce `--sample_n_frames`.
118 |
119 |
120 | ### If you find this work helpful, please consider citing:
121 | ```
122 | @article{peng2024controlnext,
123 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation},
124 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya},
125 | journal={arXiv preprint arXiv:2408.06070},
126 | year={2024}
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 1
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | fsdp_config: {}
11 | machine_rank: 0
12 | main_process_ip: null
13 | main_process_port: null
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 8
18 | use_cpu: false
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/meta_info_example/meta_info.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "video_path": "PATH_TO_THE_SOURCE_VIDEO",
4 | "guide_path": "PATH_TO_THE_CORRESPONDING_POSE_VIDEO",
5 | "meta_info": "PATH_TO_THE_JSON_FILE_RECORD_THE_DETAILED_DETECTION_RESULTS(we give an example in meta_info/1.json)"
6 | }
7 | ]
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/models/controlnext_vid_svd.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D
10 |
11 |
12 | class ControlNeXtSDVModel(ModelMixin, ConfigMixin):
13 | _supports_gradient_checkpointing = True
14 |
15 | @register_to_config
16 | def __init__(
17 | self,
18 | time_embed_dim = 256,
19 | in_channels = [128, 128],
20 | out_channels = [128, 256],
21 | groups = [4, 8]
22 | ):
23 | super().__init__()
24 |
25 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
26 | self.time_embedding = TimestepEmbedding(128, time_embed_dim)
27 | self.embedding = nn.Sequential(
28 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
29 | nn.GroupNorm(2, 64),
30 | nn.ReLU(),
31 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
32 | nn.GroupNorm(2, 64),
33 | nn.ReLU(),
34 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
35 | nn.GroupNorm(2, 128),
36 | nn.ReLU(),
37 | )
38 |
39 | self.down_res = nn.ModuleList()
40 | self.down_sample = nn.ModuleList()
41 | for i in range(len(in_channels)):
42 | self.down_res.append(
43 | ResnetBlock2D(
44 | in_channels=in_channels[i],
45 | out_channels=out_channels[i],
46 | temb_channels=time_embed_dim,
47 | groups=groups[i]
48 | ),
49 | )
50 | self.down_sample.append(
51 | Downsample2D(
52 | out_channels[i],
53 | use_conv=True,
54 | out_channels=out_channels[i],
55 | padding=1,
56 | name="op",
57 | )
58 | )
59 |
60 | self.mid_convs = nn.ModuleList()
61 | self.mid_convs.append(nn.Sequential(
62 | nn.Conv2d(
63 | in_channels=out_channels[-1],
64 | out_channels=out_channels[-1],
65 | kernel_size=3,
66 | stride=1,
67 | padding=1
68 | ),
69 | nn.ReLU(),
70 | nn.GroupNorm(8, out_channels[-1]),
71 | nn.Conv2d(
72 | in_channels=out_channels[-1],
73 | out_channels=out_channels[-1],
74 | kernel_size=3,
75 | stride=1,
76 | padding=1
77 | ),
78 | nn.GroupNorm(8, out_channels[-1]),
79 | ))
80 | self.mid_convs.append(
81 | nn.Conv2d(
82 | in_channels=out_channels[-1],
83 | out_channels=320,
84 | kernel_size=1,
85 | stride=1,
86 | ))
87 |
88 | self.scale = 1.
89 |
90 | def _set_gradient_checkpointing(self, module, value=False):
91 | if hasattr(module, "gradient_checkpointing"):
92 | module.gradient_checkpointing = value
93 |
94 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
95 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
96 | """
97 | Sets the attention processor to use [feed forward
98 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
99 |
100 | Parameters:
101 | chunk_size (`int`, *optional*):
102 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
103 | over each tensor of dim=`dim`.
104 | dim (`int`, *optional*, defaults to `0`):
105 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
106 | or dim=1 (sequence length).
107 | """
108 | if dim not in [0, 1]:
109 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
110 |
111 | # By default chunk size is 1
112 | chunk_size = chunk_size or 1
113 |
114 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
115 | if hasattr(module, "set_chunk_feed_forward"):
116 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
117 |
118 | for child in module.children():
119 | fn_recursive_feed_forward(child, chunk_size, dim)
120 |
121 | for module in self.children():
122 | fn_recursive_feed_forward(module, chunk_size, dim)
123 |
124 | def forward(
125 | self,
126 | sample: torch.FloatTensor,
127 | timestep: Union[torch.Tensor, float, int],
128 | ):
129 |
130 | timesteps = timestep
131 | if not torch.is_tensor(timesteps):
132 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
133 | # This would be a good case for the `match` statement (Python 3.10+)
134 | is_mps = sample.device.type == "mps"
135 | if isinstance(timestep, float):
136 | dtype = torch.float32 if is_mps else torch.float64
137 | else:
138 | dtype = torch.int32 if is_mps else torch.int64
139 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
140 | elif len(timesteps.shape) == 0:
141 | timesteps = timesteps[None].to(sample.device)
142 |
143 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
144 | batch_size, num_frames = sample.shape[:2]
145 | timesteps = timesteps.expand(batch_size)
146 |
147 | t_emb = self.time_proj(timesteps)
148 |
149 | # `Timesteps` does not contain any weights and will always return f32 tensors
150 | # but time_embedding might actually be running in fp16. so we need to cast here.
151 | # there might be better ways to encapsulate this.
152 | t_emb = t_emb.to(dtype=sample.dtype)
153 |
154 | emb_batch = self.time_embedding(t_emb)
155 |
156 | # Flatten the batch and frames dimensions
157 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
158 | sample = sample.flatten(0, 1)
159 | # Repeat the embeddings num_video_frames times
160 | # emb: [batch, channels] -> [batch * frames, channels]
161 | emb = emb_batch.repeat_interleave(num_frames, dim=0)
162 |
163 | sample = self.embedding(sample)
164 |
165 | for res, downsample in zip(self.down_res, self.down_sample):
166 | sample = res(sample, emb)
167 | sample = downsample(sample, emb)
168 |
169 | sample = self.mid_convs[0](sample) + sample
170 | sample = self.mid_convs[1](sample)
171 |
172 | return {
173 | 'output': sample,
174 | 'scale': self.scale,
175 | }
176 |
177 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.31.0
3 | addict==2.4.0
4 | aiofiles==23.2.1
5 | aiohttp==3.9.5
6 | aiosignal==1.3.1
7 | albumentations==1.3.1
8 | altair==5.3.0
9 | annotated-types==0.7.0
10 | antlr4-python3-runtime==4.8
11 | anyio==4.4.0
12 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
13 | async-timeout==4.0.3
14 | attrs==23.2.0
15 | av==12.0.0
16 | azureml==0.2.7
17 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
18 | basicsr==1.4.2
19 | blessed==1.20.0
20 | boto3==1.34.123
21 | botocore==1.34.123
22 | braceexpand==0.1.7
23 | cachetools==5.3.3
24 | certifi==2024.6.2
25 | cffi==1.16.0
26 | chardet==5.2.0
27 | charset-normalizer==3.3.2
28 | chumpy==0.70
29 | clean-fid==0.1.35
30 | click==8.1.7
31 | clip-anytorch==2.6.0
32 | cmake==3.29.3
33 | coloredlogs==15.0.1
34 | colorlog==6.8.2
35 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
36 | contourpy==1.1.1
37 | cycler==0.12.1
38 | datasets==2.19.1
39 | dctorch==0.1.2
40 | debugpy @ file:///croot/debugpy_1690905042057/work
41 | decorator==4.4.2
42 | decord==0.6.0
43 | deepspeed==0.14.5
44 | -e git+https://github.com/huggingface/diffusers.git@983dec3bf787c064ed57f2621c9b7375d443f746#egg=diffusers
45 | dill==0.3.8
46 | dnspython==2.6.1
47 | docker-pycreds==0.4.0
48 | editor==1.6.6
49 | einops==0.8.0
50 | einops-exts==0.0.4
51 | email_validator==2.1.1
52 | embreex==2.17.7.post4
53 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
54 | exceptiongroup==1.2.1
55 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
56 | facexlib==0.3.0
57 | fairscale==0.4.13
58 | fastapi==0.111.0
59 | fastapi-cli==0.0.4
60 | ffmpy==0.3.2
61 | filelock==3.15.1
62 | filterpy==1.4.5
63 | flatbuffers==24.3.25
64 | fonttools==4.52.4
65 | frozenlist==1.4.1
66 | fsspec==2024.6.0
67 | ftfy==6.2.0
68 | future==1.0.0
69 | gitdb==4.0.11
70 | GitPython==3.1.43
71 | google-auth==2.29.0
72 | google-auth-oauthlib==1.0.0
73 | gradio==4.31.5
74 | gradio_client==0.16.4
75 | gradio_imageslider==0.0.18
76 | grpcio==1.64.0
77 | h11==0.14.0
78 | hjson==3.1.0
79 | httpcore==1.0.5
80 | httptools==0.6.1
81 | httpx==0.27.0
82 | huggingface-hub==0.23.3
83 | humanfriendly==10.0
84 | idna==3.7
85 | imageio==2.34.1
86 | imageio-ffmpeg==0.4.9
87 | importlib_metadata==7.1.0
88 | importlib_resources==6.4.0
89 | inquirer==3.3.0
90 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
91 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work
92 | jax==0.4.13
93 | jaxlib==0.4.13
94 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
95 | Jinja2==3.1.4
96 | jmespath==1.0.1
97 | joblib==1.4.2
98 | jsonmerge==1.9.2
99 | jsonschema==4.22.0
100 | jsonschema-specifications==2023.12.1
101 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
102 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257397447/work
103 | k-diffusion==0.1.1.post1
104 | kiwisolver==1.4.5
105 | kornia==0.7.2
106 | kornia_rs==0.1.3
107 | lazy_loader==0.4
108 | lightning-utilities==0.11.2
109 | lit==18.1.6
110 | llvmlite==0.41.1
111 | lmdb==1.4.1
112 | lxml==5.2.2
113 | manopth @ file:///home/llm/bhpeng/generation/HandRefiner/MeshGraphormer/manopth
114 | mapbox-earcut==1.0.1
115 | Markdown==3.6
116 | markdown-it-py==3.0.0
117 | MarkupSafe==2.1.5
118 | matplotlib==3.7.5
119 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
120 | mdurl==0.1.2
121 | mediapipe==0.10.0
122 | ml-dtypes==0.2.0
123 | moviepy==1.0.3
124 | mpmath==1.3.0
125 | multidict==6.0.5
126 | multiprocess==0.70.16
127 | mypy-extensions==1.0.0
128 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
129 | networkx==3.1
130 | ninja==1.11.1.1
131 | numba==0.58.1
132 | numpy==1.24.4
133 | nvidia-cublas-cu11==11.10.3.66
134 | nvidia-cublas-cu12==12.1.3.1
135 | nvidia-cuda-cupti-cu11==11.7.101
136 | nvidia-cuda-cupti-cu12==12.1.105
137 | nvidia-cuda-nvrtc-cu11==11.7.99
138 | nvidia-cuda-nvrtc-cu12==12.1.105
139 | nvidia-cuda-runtime-cu11==11.7.99
140 | nvidia-cuda-runtime-cu12==12.1.105
141 | nvidia-cudnn-cu11==8.5.0.96
142 | nvidia-cudnn-cu12==8.9.2.26
143 | nvidia-cufft-cu11==10.9.0.58
144 | nvidia-cufft-cu12==11.0.2.54
145 | nvidia-curand-cu11==10.2.10.91
146 | nvidia-curand-cu12==10.3.2.106
147 | nvidia-cusolver-cu11==11.4.0.1
148 | nvidia-cusolver-cu12==11.4.5.107
149 | nvidia-cusparse-cu11==11.7.4.91
150 | nvidia-cusparse-cu12==12.1.0.106
151 | nvidia-ml-py==12.560.30
152 | nvidia-nccl-cu11==2.14.3
153 | nvidia-nccl-cu12==2.20.5
154 | nvidia-nvjitlink-cu12==12.5.40
155 | nvidia-nvtx-cu11==11.7.91
156 | nvidia-nvtx-cu12==12.1.105
157 | oauthlib==3.2.2
158 | omegaconf==2.1.1
159 | onnxruntime-gpu==1.17.1
160 | open-clip-torch==2.22.0
161 | openai-clip==1.0.1
162 | opencv-contrib-python==4.10.0.82
163 | opencv-python==4.9.0.80
164 | opencv-python-headless==4.10.0.82
165 | opt-einsum==3.3.0
166 | orjson==3.10.3
167 | packaging==24.1
168 | pandas==2.0.0
169 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
170 | peft==0.11.1
171 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
172 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
173 | Pillow==9.5.0
174 | pkgutil_resolve_name==1.3.10
175 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
176 | proglog==0.1.10
177 | prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
178 | protobuf==5.27.2
179 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
180 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
181 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
182 | py-cpuinfo==9.0.0
183 | pyarrow==16.1.0
184 | pyarrow-hotfix==0.6
185 | pyasn1==0.6.0
186 | pyasn1_modules==0.4.0
187 | pycollada==0.8
188 | pycparser==2.22
189 | pydantic==2.7.1
190 | pydantic_core==2.18.2
191 | pydub==0.25.1
192 | pygifsicle==1.0.7
193 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
194 | pyparsing==3.1.2
195 | pyre-extensions==0.0.29
196 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
197 | python-dotenv==1.0.1
198 | python-multipart==0.0.9
199 | pytorch-lightning==1.9.3
200 | pytorch-pretrained-bert==0.6.2
201 | pytz==2024.1
202 | PyWavelets==1.4.1
203 | PyYAML==6.0.1
204 | pyzmq @ file:///croot/pyzmq_1705605076900/work
205 | qudida==0.0.4
206 | readchar==4.1.0
207 | referencing==0.35.1
208 | regex==2024.5.15
209 | requests==2.32.3
210 | requests-oauthlib==2.0.0
211 | rich==13.7.1
212 | rpds-py==0.18.1
213 | rsa==4.9
214 | Rtree==1.2.0
215 | ruff==0.4.5
216 | runs==1.2.2
217 | s3transfer==0.10.1
218 | safetensors==0.4.3
219 | scikit-image==0.21.0
220 | scikit-learn==1.3.2
221 | scipy==1.10.1
222 | semantic-version==2.10.0
223 | sentencepiece==0.2.0
224 | sentry-sdk==2.5.1
225 | setproctitle==1.3.3
226 | shapely==2.0.4
227 | shellingham==1.5.4
228 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
229 | smmap==5.0.1
230 | sniffio==1.3.1
231 | sounddevice==0.4.7
232 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
233 | starlette==0.37.2
234 | support-developer==1.0.5
235 | svg.path==6.3
236 | sympy==1.12.1
237 | tb-nightly==2.14.0a20230808
238 | tensorboard==2.14.0
239 | tensorboard-data-server==0.7.2
240 | tensorboardX==2.6.2.2
241 | threadpoolctl==3.5.0
242 | tifffile==2023.7.10
243 | timm==0.6.13
244 | tokenizers==0.19.1
245 | tomli==2.0.1
246 | tomlkit==0.12.0
247 | toolz==0.12.1
248 | torch==2.3.1
249 | torchdiffeq==0.2.4
250 | torchmetrics==1.4.0.post0
251 | torchsde==0.2.6
252 | torchvision==0.15.1
253 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827257044/work
254 | tqdm==4.66.4
255 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
256 | trampoline==0.1.2
257 | transformers==4.41.2
258 | trimesh==3.23.5
259 | triton==2.3.1
260 | typer==0.12.3
261 | typing-inspect==0.9.0
262 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
263 | tzdata==2024.1
264 | ujson==5.10.0
265 | urllib3==2.2.1
266 | uvicorn==0.30.0
267 | uvloop==0.19.0
268 | wandb==0.17.1
269 | watchfiles==0.22.0
270 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
271 | webdataset==0.2.86
272 | websockets==11.0.3
273 | Werkzeug==3.0.3
274 | xformers==0.0.26.post1
275 | xmod==1.8.1
276 | xxhash==3.4.1
277 | yacs==0.1.8
278 | yapf==0.40.2
279 | yarl==1.9.4
280 | zipp==3.19.0
281 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/script.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file ./deepspeed.yaml train_svd.py \
5 | --pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
6 | --output_dir=$PATH_TO_THE_SAVE_DIR \
7 | --dataset_type="ubc" \
8 | --meta_info_path=$PATH_TO_THE_META_INFO_FILE_FOR_DATASET \
9 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_EVALUATION \
10 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_EVALUATION \
11 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_EVALUATION \
12 | --width=576 \
13 | --height=1024 \
14 | --lr_warmup_steps 500 \
15 | --sample_n_frames 14 \
16 | --interval_frame 3 \
17 | --learning_rate=1e-5 \
18 | --per_gpu_batch_size=1 \
19 | --num_train_epochs=6000 \
20 | --mixed_precision="bf16" \
21 | --gradient_accumulation_steps=1 \
22 | --checkpointing_steps=2000 \
23 | --validation_steps=500 \
24 | --gradient_checkpointing \
25 | --checkpoints_total_limit 4
26 |
27 | # For Resume
28 | --controlnet_model_name_or_path $PATH_TO_THE_CONTROLNEXT_WEIGHT
29 | --unet_model_name_or_path $PATH_TO_THE_UNET_WEIGHT
30 |
31 |
32 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os, io, csv, math, random
2 | import numpy as np
3 | from einops import rearrange
4 |
5 | import torch
6 | from decord import VideoReader
7 | import cv2
8 |
9 | import torchvision.transforms as transforms
10 | from torch.utils.data.dataset import Dataset
11 | from utils.util import zero_rank_print
12 | #from torchvision.io import read_image
13 | from PIL import Image
14 | def pil_image_to_numpy(image):
15 | """Convert a PIL image to a NumPy array."""
16 | if image.mode != 'RGB':
17 | image = image.convert('RGB')
18 | return np.array(image)
19 |
20 | def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
21 | """Convert a NumPy image to a PyTorch tensor."""
22 | if images.ndim == 3:
23 | images = images[..., None]
24 | images = torch.from_numpy(images.transpose(0, 3, 1, 2))
25 | return images.float() / 255
26 |
27 |
28 | class WebVid10M(Dataset):
29 | def __init__(
30 | self,
31 | csv_path, video_folder,depth_folder,motion_folder,
32 | sample_size=256, sample_stride=4, sample_n_frames=14,
33 | ):
34 | zero_rank_print(f"loading annotations from {csv_path} ...")
35 | with open(csv_path, 'r') as csvfile:
36 | self.dataset = list(csv.DictReader(csvfile))
37 | self.length = len(self.dataset)
38 | print(f"data scale: {self.length}")
39 | random.shuffle(self.dataset)
40 | self.video_folder = video_folder
41 | self.sample_stride = sample_stride
42 | self.sample_n_frames = sample_n_frames
43 | self.depth_folder = depth_folder
44 | self.motion_values_folder=motion_folder
45 | print("length",len(self.dataset))
46 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
47 | print("sample size",sample_size)
48 | self.pixel_transforms = transforms.Compose([
49 | transforms.RandomHorizontalFlip(),
50 | transforms.Resize(sample_size),
51 | transforms.CenterCrop(sample_size),
52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
53 | ])
54 |
55 |
56 |
57 |
58 |
59 | def center_crop(self,img):
60 | h, w = img.shape[-2:] # Assuming img shape is [C, H, W] or [B, C, H, W]
61 | min_dim = min(h, w)
62 | top = (h - min_dim) // 2
63 | left = (w - min_dim) // 2
64 | return img[..., top:top+min_dim, left:left+min_dim]
65 |
66 |
67 | def get_batch(self, idx):
68 | def sort_frames(frame_name):
69 | return int(frame_name.split('_')[1].split('.')[0])
70 |
71 |
72 |
73 | while True:
74 | video_dict = self.dataset[idx]
75 | videoid = video_dict['videoid']
76 |
77 | preprocessed_dir = os.path.join(self.video_folder, videoid)
78 | depth_folder = os.path.join(self.depth_folder, videoid)
79 | motion_values_file = os.path.join(self.motion_values_folder, videoid, videoid + "_average_motion.txt")
80 |
81 | if not os.path.exists(depth_folder) or not os.path.exists(motion_values_file):
82 | idx = random.randint(0, len(self.dataset) - 1)
83 | continue
84 |
85 | # Sort and limit the number of image and depth files to 14
86 | image_files = sorted(os.listdir(preprocessed_dir), key=sort_frames)[:14]
87 | depth_files = sorted(os.listdir(depth_folder), key=sort_frames)[:14]
88 |
89 | # Check if there are enough frames for both image and depth
90 | if len(image_files) < 14 or len(depth_files) < 14:
91 | idx = random.randint(0, len(self.dataset) - 1)
92 | continue
93 |
94 | # Load image frames
95 | numpy_images = np.array([pil_image_to_numpy(Image.open(os.path.join(preprocessed_dir, img))) for img in image_files])
96 | pixel_values = numpy_to_pt(numpy_images)
97 |
98 | # Load depth frames
99 | numpy_depth_images = np.array([pil_image_to_numpy(Image.open(os.path.join(depth_folder, df))) for df in depth_files])
100 | depth_pixel_values = numpy_to_pt(numpy_depth_images)
101 |
102 | # Load motion values
103 | with open(motion_values_file, 'r') as file:
104 | motion_values = float(file.read().strip())
105 |
106 | return pixel_values, depth_pixel_values, motion_values
107 |
108 |
109 |
110 |
111 | def __len__(self):
112 | return self.length
113 |
114 | def __getitem__(self, idx):
115 |
116 | #while True:
117 | # try:
118 | pixel_values, depth_pixel_values,motion_values = self.get_batch(idx)
119 | # break
120 | # except Exception as e:
121 | # print(e)
122 | # idx = random.randint(0, self.length - 1)
123 |
124 | pixel_values = self.pixel_transforms(pixel_values)
125 | sample = dict(pixel_values=pixel_values, depth_pixel_values=depth_pixel_values,motion_values=motion_values)
126 | return sample
127 |
128 |
129 |
130 |
131 | if __name__ == "__main__":
132 | from utils.util import save_videos_grid
133 |
134 |     dataset = WebVid10M(
135 |         csv_path="/data/webvid/results_2M_train.csv",
136 |         video_folder="/data/webvid/data/videos",
137 |         depth_folder="/data/webvid/data/depth", motion_folder="/data/webvid/data/motion",  # placeholder paths
138 |         sample_size=256,
139 |         sample_stride=4, sample_n_frames=16,
140 |     )
141 | import pdb
142 | pdb.set_trace()
143 |
144 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
145 | for idx, batch in enumerate(dataloader):
146 |         print(batch["pixel_values"].shape, batch["depth_pixel_values"].shape, batch["motion_values"])
147 | # for i in range(batch["pixel_values"].shape[0]):
148 | # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
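# Expected on-disk layout, as read by get_batch() above (file names are illustrative):
#   <video_folder>/<videoid>/frame_0.png, frame_1.png, ...    RGB frames
#   <depth_folder>/<videoid>/frame_0.png, frame_1.png, ...    conditioning frames
#   <motion_folder>/<videoid>/<videoid>_average_motion.txt    a single float motion value
# Frames are sorted by the integer after the first underscore, and each clip needs
# at least 14 frames in both folders.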
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/utils/extract_learned_paras.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import argparse
4 | import torch
5 | import os
6 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNetModel
7 | from safetensors.torch import save_file, load_file
8 |
9 |
10 | """
11 | python -m utils.extract_learned_paras \
12 | /home/llm/bhpeng/generation/svd-temporal-controlnet/outputs_sdxt_mid_upblocks/ft_on_duqi_after_hands500_checkpoint-200/unet/unet_fp16.bin \
13 | /home/llm/bhpeng/generation/svd-temporal-controlnet/outputs_sdxt_mid_upblocks/ft_on_duqi_after_hands500_checkpoint-200/unet/unet_fp16_increase.bin \
14 | --pretrained_path /home/llm/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid-xt-1-1/snapshots/a423ba0d3e1a94a57ebc68e98691c43104198394
15 | """
16 |
17 |
18 | if __name__ == "__main__":
19 |
20 | parser = argparse.ArgumentParser()
21 |     parser.add_argument("src_path",
22 |                         type=str,
23 |                         help="path to the fine-tuned state dict (.bin or .safetensors)")
24 |     parser.add_argument("dst_path",
25 |                         type=str,
26 |                         help="path to save the extracted state dict")
27 | parser.add_argument("--pretrained_path",
28 | type=str,
29 | default="/home/llm/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid/snapshots/ae8391f7321be9ff8941508123715417da827aa4")
30 | parser.add_argument("--save_as_fp32",
31 | action="store_true",)
32 | parser.add_argument("--save_weight_increase",
33 | action="store_true",)
34 | args = parser.parse_args()
35 |
36 | unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
37 | args.pretrained_path,
38 | subfolder="unet",
39 | low_cpu_mem_usage=True,
40 | variant="fp16",
41 | )
42 | pretrained_state_dict = unet.state_dict()
43 |
44 | if os.path.splitext(args.src_path)[1] == ".bin":
45 | src_state_dict = torch.load(args.src_path)
46 | elif os.path.splitext(args.src_path)[1] == ".safetensors":
47 | src_state_dict = load_file(args.src_path)
48 |
49 | for k in list(src_state_dict.keys()):
50 | src_state_dict[k] = src_state_dict[k].to(pretrained_state_dict[k])
51 | if torch.allclose(src_state_dict[k], pretrained_state_dict[k]):
52 | src_state_dict.pop(k)
53 | continue
54 | if args.save_weight_increase:
55 | src_state_dict[k] = src_state_dict[k] - pretrained_state_dict[k]
56 | if not args.save_as_fp32:
57 | src_state_dict[k] = src_state_dict[k].half()
58 |
59 |
60 | if os.path.splitext(args.dst_path)[1] == ".bin":
61 | torch.save(src_state_dict, args.dst_path)
62 | elif os.path.splitext(args.dst_path)[1] == ".safetensors":
63 | save_file(src_state_dict, args.dst_path)
64 |
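# The loop above drops every parameter that is unchanged relative to the pretrained SVD UNet,
# optionally stores the difference itself (--save_weight_increase), and casts the remaining
# tensors to fp16 unless --save_as_fp32 is set, so only the learned parameters are exported.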
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/utils/extract_vid2img.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import argparse
4 | import random
5 | import os
6 | from decord import VideoReader
7 | import cv2
8 |
9 |
10 | """
11 | python -m utils.extract_vid2img /home/llm/bhpeng/generation/svd-temporal-controlnet/proj/dataset/cropped/v2/users/肚脐小师妹/7296828856386784548.mp4 /home/llm/bhpeng/generation/svd-temporal-controlnet/validation_demo/test/0
12 | """
13 |
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("video_path",
19 | type=str,
20 | help="path to the video")
21 | parser.add_argument("save_dict",
22 | type=str,
23 | help="path to the save_dict")
24 | parser.add_argument("--interval_frame",
25 | type=int,
26 | default=2)
27 | parser.add_argument("--sample_n_frames",
28 | type=int,
29 | default=1)
30 | args = parser.parse_args()
31 |
32 |
33 | video_path = args.video_path
34 | pose_path = video_path.replace("/users", "/pose")
35 |
36 | save_video_path = os.path.join(args.save_dict, "rgb")
37 | save_pose_path = os.path.join(args.save_dict, "pose")
38 |
39 | if not os.path.exists(save_video_path):
40 | os.makedirs(save_video_path)
41 | if not os.path.exists(save_pose_path):
42 | os.makedirs(save_pose_path)
43 |
44 | vr = VideoReader(video_path)
45 | length = len(vr)
46 | segment_length = args.interval_frame * args.sample_n_frames
47 | assert length >= segment_length, "Too short video..."
48 | bg_frame_id = random.randint(0, length - segment_length)
49 | frame_ids = list(range(bg_frame_id, bg_frame_id + segment_length, args.interval_frame))
50 | for idx, fid in enumerate(frame_ids):
51 | frame = vr[fid].asnumpy()
52 |         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # decord returns RGB; cv2.imwrite expects BGR
53 |         cv2.imwrite(os.path.join(save_video_path, "{}.png".format(idx)), frame)
54 |
55 |
56 | vr = VideoReader(pose_path)
57 | for idx, fid in enumerate(frame_ids):
58 | frame = vr[fid].asnumpy()
59 |         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # decord returns RGB; cv2.imwrite expects BGR
60 |         cv2.imwrite(os.path.join(save_pose_path, "{}.png".format(idx)), frame)
61 |
62 |
63 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/utils/unwrap_deepspeed.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import argparse
4 | import torch
5 | import os
6 | from collections import OrderedDict
7 |
8 |
9 | if __name__ == "__main__":
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("checkpoint_dir",
13 | type=str,
14 | help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
15 | args = parser.parse_args()
16 | state_dict = torch.load(os.path.join(args.checkpoint_dir, "pytorch_model.bin"))
17 |
18 | unet_state = OrderedDict()
19 | controlnext_state = OrderedDict()
20 | for name, data in state_dict.items():
21 | model = name.split('.', 1)[0]
22 | module = name.split('.', 1)[1]
23 | if model == 'unet':
24 | unet_state[module] = data
25 | elif model == 'controlnext':
26 | controlnext_state[module] = data
27 |
28 | for model in ['unet', 'controlnext']:
29 | if not os.path.exists(os.path.join(args.checkpoint_dir, model)):
30 | os.makedirs(os.path.join(args.checkpoint_dir, model))
31 |
32 |
33 | torch.save(unet_state, os.path.join(args.checkpoint_dir, "unet", "diffusion_pytorch_model.bin"))
34 | torch.save(controlnext_state, os.path.join(args.checkpoint_dir, "controlnext", "diffusion_pytorch_model.bin"))
35 |
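# Usage sketch (the checkpoint path is a placeholder), following the `python -m utils.<tool>`
# pattern used by the other utilities:
#   python -m utils.unwrap_deepspeed path/checkpoint-12
# This splits the merged pytorch_model.bin into unet/diffusion_pytorch_model.bin and
# controlnext/diffusion_pytorch_model.bin inside the checkpoint folder.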
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2-Training/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 | import torch.distributed as dist
9 |
10 | from safetensors import safe_open
11 | from tqdm import tqdm
12 | from einops import rearrange
13 |
14 |
15 | def zero_rank_print(s):
16 |     if (not dist.is_initialized()) or dist.get_rank() == 0: print("### " + s)
17 |
18 |
19 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
20 | videos = rearrange(videos, "b c t h w -> t b c h w")
21 | outputs = []
22 | for x in videos:
23 | x = torchvision.utils.make_grid(x, nrow=n_rows)
24 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
25 | if rescale:
26 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
27 | x = (x * 255).numpy().astype(np.uint8)
28 | outputs.append(x)
29 |
30 | os.makedirs(os.path.dirname(path), exist_ok=True)
31 | imageio.mimsave(path, outputs, fps=fps)
32 |
33 |
34 | # DDIM Inversion
35 | @torch.no_grad()
36 | def init_prompt(prompt, pipeline):
37 | uncond_input = pipeline.tokenizer(
38 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
39 | return_tensors="pt"
40 | )
41 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
42 | text_input = pipeline.tokenizer(
43 | [prompt],
44 | padding="max_length",
45 | max_length=pipeline.tokenizer.model_max_length,
46 | truncation=True,
47 | return_tensors="pt",
48 | )
49 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
50 | context = torch.cat([uncond_embeddings, text_embeddings])
51 |
52 | return context
53 |
54 |
55 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
56 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
57 | timestep, next_timestep = min(
58 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
59 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
60 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
61 | beta_prod_t = 1 - alpha_prod_t
62 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
63 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
64 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
65 | return next_sample
66 |
67 |
68 | def get_noise_pred_single(latents, t, context, unet):
69 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
70 | return noise_pred
71 |
72 |
73 | @torch.no_grad()
74 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
75 | context = init_prompt(prompt, pipeline)
76 | uncond_embeddings, cond_embeddings = context.chunk(2)
77 | all_latent = [latent]
78 | latent = latent.clone().detach()
79 | for i in tqdm(range(num_inv_steps)):
80 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
81 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
82 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
83 | all_latent.append(latent)
84 | return all_latent
85 |
86 |
87 | @torch.no_grad()
88 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
89 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
90 | return ddim_latents
91 |
92 |
93 | # def load_weights(
94 | # animation_pipeline,
95 | # # motion module
96 | # motion_module_path = "",
97 | # motion_module_lora_configs = [],
98 | # # image layers
99 | # dreambooth_model_path = "",
100 | # lora_model_path = "",
101 | # lora_alpha = 0.8,
102 | # ):
103 | # # 1.1 motion module
104 | # unet_state_dict = {}
105 | # if motion_module_path != "":
106 | # print(f"load motion module from {motion_module_path}")
107 | # motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
108 | # motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
109 | # unet_state_dict.update({name.replace("module.", ""): param for name, param in motion_module_state_dict.items()})
110 |
111 | # missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
112 | # assert len(unexpected) == 0
113 | # del unet_state_dict
114 |
115 | # # if dreambooth_model_path != "":
116 | # # print(f"load dreambooth model from {dreambooth_model_path}")
117 | # # if dreambooth_model_path.endswith(".safetensors"):
118 | # # dreambooth_state_dict = {}
119 | # # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
120 | # # for key in f.keys():
121 | # # dreambooth_state_dict[key.replace("module.", "")] = f.get_tensor(key)
122 | # # elif dreambooth_model_path.endswith(".ckpt"):
123 | # # dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
124 | # # dreambooth_state_dict = {k.replace("module.", ""): v for k, v in dreambooth_state_dict.items()}
125 |
126 | # # 1. vae
127 | # # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
128 | # # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
129 | # # 2. unet
130 | # # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
131 | # # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
132 | # # 3. text_model
133 | # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
134 | # # del dreambooth_state_dict
135 |
136 | # if lora_model_path != "":
137 | # print(f"load lora model from {lora_model_path}")
138 | # assert lora_model_path.endswith(".safetensors")
139 | # lora_state_dict = {}
140 | # with safe_open(lora_model_path, framework="pt", device="cpu") as f:
141 | # for key in f.keys():
142 | # lora_state_dict[key.replace("module.", "")] = f.get_tensor(key)
143 |
144 | # animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
145 | # del lora_state_dict
146 |
147 | # for motion_module_lora_config in motion_module_lora_configs:
148 | # path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
149 | # print(f"load motion LoRA from {path}")
150 |
151 | # motion_lora_state_dict = torch.load(path, map_location="cpu")
152 | # motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
153 | # motion_lora_state_dict = {k.replace("module.", ""): v for k, v in motion_lora_state_dict.items()}
154 |
155 | # animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
156 |
157 | # return animation_pipeline
158 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SVD-v2
2 |
3 | This is our implementation of ControlNeXt based on [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). It can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) with a more concise and efficient architecture.
4 |
5 | Compared to image generation, video generation poses significantly greater challenges. While direct training of the generation model with our method is feasible, we also employ various engineering strategies to enhance performance, although these are independent of the academic algorithm itself.
6 |
7 |
8 | > Please refer to [Examples](#examples) for further intuitive details.\
9 | > Please refer to [Base model](#base-model) for more details of our used base model. \
10 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\
11 | > Please refer to [Advanced Performance](#advanced-performance) for more details to achieve a better performance.\
12 | > Please refer to [Limitations](#limitations) for more details about the limitations of current work.
13 |
14 | # Examples
15 | If you can't load the videos, you can also directly download them from [here](examples/demos) and [here](examples/video).
16 | Or you can view them from our [Project Page](https://pbihao.github.io/projects/controlnext/index.html) or [BiliBili](https://www.bilibili.com/video/BV1wJYbebEE7/?buvid=YC4E03C93B119ADD4080B0958DE73F9DDCAC&from_spmid=dt.dt.video.0&is_story_h5=false&mid=y82Gz7uArS6jTQ6zuqJj3w%3D%3D&p=1&plat_id=114&share_from=ugc&share_medium=iphone&share_plat=ios&share_session_id=4E5549FC-0710-4030-BD2C-CDED80B46D08&share_source=WEIXIN&share_source=weixin&share_tag=s_i&timestamp=1723123770&unique_k=XLZLhCq&up_id=176095810&vd_source=3791450598e16da25ecc2477fc7983db).
17 |
18 |
19 | <!-- Demo videos: examples/demos/01-1.mp4, 02-1.mp4, 03-1.mp4, 04-1.mp4 -->
20 |
43 | # Base Model
44 |
45 | For the v2 version, we adopt the following measures to improve performance:
46 | * We have collected a higher-quality dataset with higher resolution to train our model.
47 | * We have extended the training and inference batch frames to 24.
48 | * We have extended the video height and width to a resolution of 576 × 1024.
49 | * We conduct extensive continual training of SVD on human-related videos to enhance its ability to generate human-related content.
50 | * We adopt fp32.
51 | * We adopt pose alignment during inference, following related works.
52 |
53 | # Inference
54 |
55 | 1. Clone our repository
56 | 2. `cd ControlNeXt-SVD-v2`
57 | 3. Download the pretrained weight into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/v2). (More details please refer to [Base Model](#base-model))
58 | 4. Download the DWPose weights including the [dw-ll_ucoco_384](https://drive.google.com/file/d/12L8E2oAgZy4VACGSK9RaZBZrfgx7VTA2/view?usp=sharing) and [yolox_l](https://drive.google.com/file/d/1w9pXC8tT0p9ndMN-CArp1__b2GbzewWI/view?usp=sharing) into `pretrained/DWPose`. For more details, please refer to [DWPose](https://github.com/IDEA-Research/DWPose):
59 | ```
60 | ├───pretrained
61 | └───DWPose
62 | | │───dw-ll_ucoco_384.onnx
63 | | └───yolox_l.onnx
64 | |
65 | ├───unet.bin
66 | └───controlnet.bin
67 | ```
68 | 5. Run the script
69 |
70 | ```bash
71 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
72 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
73 | --output_dir outputs \
74 | --max_frame_num 240 \
75 | --guidance_scale 3 \
76 | --batch_frames 24 \
77 | --sample_stride 2 \
78 | --overlap 6 \
79 | --height 1024 \
80 | --width 576 \
81 | --controlnext_path pretrained/controlnet.bin \
82 | --unet_path pretrained/unet.bin \
83 | --validation_control_video_path examples/video/02.mp4 \
84 | --ref_image_path examples/ref_imgs/01.jpeg
85 | ```
86 |
87 | > --pretrained_model_name_or_path : the pretrained base model; we pretrain and fine-tune models based on [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1)\
88 | > --controlnext_path : the model path of the ControlNeXt module (a lightweight module) \
89 | > --unet_path : the model path of the UNet \
90 | > --ref_image_path: the path to the reference image \
91 | > --overlap: the number of overlapped frames for long-video generation (see the sketch below) \
92 | > --sample_stride: the sampling stride for the conditional controls. You can set it to `1` for smoother generation, at the cost of more computation.
93 |
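As a rough sketch of how `--max_frame_num`, `--batch_frames` and `--overlap` interact for long videos (this only illustrates the flag semantics under a simple sliding-window assumption, not the pipeline's exact tiling code):

```python
max_frame_num, batch_frames, overlap = 240, 24, 6   # values from the command above
new_per_window = batch_frames - overlap             # each extra window contributes 18 new frames
num_windows = 1 + -(-(max_frame_num - batch_frames) // new_per_window)  # ceiling division
print(num_windows)                                  # 13 overlapping windows cover 240 frames
```
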
94 | 6. Face Enhancement (optional, recommended when faces look degraded)
95 |
96 | > Currently, the model is not specifically trained for IP consistency, as there are already many mature tools available. Additionally, alternatives like Animate Anyone also adopt such post-processing techniques.
97 |
98 | a. Clone [Face Fusion](https://github.com/facefusion/facefusion): \
99 | ```git clone https://github.com/facefusion/facefusion```
100 |
101 | b. Ensure to enter the directory:\
102 | ```cd facefusion```
103 |
104 | c. Install facefusion (we recommend creating a new conda virtual environment to avoid dependency conflicts):\
105 | ```python install.py```
106 |
107 | d. Run the command:
108 | ```
109 | python run.py \
110 | -s ../outputs/demo.jpg \
111 | -t ../outputs/demo.mp4 \
112 | -o ../outputs/out.mp4 \
113 | --headless \
114 | --execution-providers cuda \
115 | --face-selector-mode one
116 | ```
117 |
118 | > -s: the reference image \
119 | > -t: the path to the original video\
120 | > -o: the path to store the refined video\
121 | > --headless: run without a GUI\
122 | > --execution-providers cuda: use CUDA for acceleration (if available; otherwise the CPU is usually sufficient)
123 |
124 | # Advanced Performance
125 | In this section, we delve into additional details and share our own experience in enhancing video generation. These factors are independent of the algorithm itself, yet they are crucial for achieving superior results, and many closely related works incorporate them.
126 |
127 | ### Reference Image
128 |
129 | It is crucial to ensure that the reference image is clear and easily understandable, especially aligning the face of the reference with the pose.
130 |
131 |
132 | ### Face Enhancement
133 |
134 | Most related works utilize face enhancement as part of the post-processing. This is especially relevant when generating videos based on images of unfamiliar individuals, such as friends, who were not included in the base model's pretraining and are therefore unseen, out-of-distribution (OOD) data.
135 |
136 | We recommend [Facefusion](https://github.com/facefusion/facefusion) for this post-processing step, and please let us know if you have a better solution.
137 |
138 | Please refer to [Facefusion](https://github.com/facefusion/facefusion) for more details.
141 |
142 | 
143 |
144 |
145 | ### Continuously Finetune
146 |
147 | To significantly enhance performance on a specific pose sequence, you can continuously fine-tune the model for just a few hundred steps.
148 |
149 | We will release the related fine-tuning code later.
150 |
151 | ### Pose Generation
152 |
153 | We adopt [DWPose](https://github.com/IDEA-Research/DWPose) for the pose generation, and follow the related work ([1](https://humanaigc.github.io/animate-anyone/), [2](https://tencent.github.io/MimicMotion/)) to align the pose.
154 |
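As a minimal sketch (run from the `ControlNeXt-SVD-v2` directory, with the DWPose ONNX weights already placed in `pretrained/DWPose` as described in [Inference](#inference); the paths below are the repository's example assets), the helpers in `dwpose/preprocess.py` can be called directly to obtain reference-aligned pose maps:

```python
import numpy as np
from PIL import Image

from dwpose.preprocess import get_image_pose, get_video_pose

# reference image as an RGB numpy array
ref_image = np.array(Image.open("examples/ref_imgs/01.jpeg").convert("RGB"))

ref_pose = get_image_pose(ref_image)   # pose map drawn from the reference image
pose_seq = get_video_pose(             # per-frame pose maps from the driving video,
    "examples/video/02.mp4",           # linearly rescaled so the driving body matches
    ref_image,                         # the reference body proportions
    sample_stride=2,
)
```

These pose maps serve as the conditional controls for generation.
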
155 | # Limitations
156 |
157 | ## IP Consistency
158 |
159 | We did not prioritize maintaining IP consistency during the development of the generation model and now rely on a helper model for face enhancement.
160 |
161 | However, additional training can be implemented to ensure IP consistency moving forward.
162 |
163 | This also leaves a possible direction for further improvement.
164 |
165 | ## Base model
166 |
167 | The base model plays a crucial role in generating human features, particularly hands and faces. We encourage collaboration to improve the base model for enhanced human-related video generation.
168 |
169 | # TODO
170 |
171 | * Training and fine-tuning code
172 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/dwpose/__init__.py
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/dwpose_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .wholebody import Wholebody
7 |
8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 | class DWposeDetector:
12 | """
13 | A pose detect method for image-like data.
14 |
15 | Parameters:
16 | model_det: (str) serialized ONNX format model path,
17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
18 | model_pose: (str) serialized ONNX format model path,
19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
20 | device: (str) 'cpu' or 'cuda:{device_id}'
21 | """
22 | def __init__(self, model_det, model_pose, device='cpu'):
23 | self.args = model_det, model_pose, device
24 |
25 | def release_memory(self):
26 | if hasattr(self, 'pose_estimation'):
27 | del self.pose_estimation
28 | import gc; gc.collect()
29 |
30 | def __call__(self, oriImg):
31 | if not hasattr(self, 'pose_estimation'):
32 | self.pose_estimation = Wholebody(*self.args)
33 |
34 | oriImg = oriImg.copy()
35 | H, W, C = oriImg.shape
36 | with torch.no_grad():
37 | candidate, score = self.pose_estimation(oriImg)
38 | nums, _, locs = candidate.shape
39 | candidate[..., 0] /= float(W)
40 | candidate[..., 1] /= float(H)
41 | body = candidate[:, :18].copy()
42 | body = body.reshape(nums * 18, locs)
43 | subset = score[:, :18].copy()
44 | for i in range(len(subset)):
45 | for j in range(len(subset[i])):
46 | if subset[i][j] > 0.3:
47 | subset[i][j] = int(18 * i + j)
48 | else:
49 | subset[i][j] = -1
50 |
51 | # un_visible = subset < 0.3
52 | # candidate[un_visible] = -1
53 |
54 | # foot = candidate[:, 18:24]
55 |
56 | faces = candidate[:, 24:92]
57 |
58 | hands = candidate[:, 92:113]
59 | hands = np.vstack([hands, candidate[:, 113:]])
60 |
61 | faces_score = score[:, 24:92]
62 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]])
63 |
64 | bodies = dict(candidate=body, subset=subset, score=score[:, :18])
65 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
66 |
67 | return pose
68 |
69 | dwpose_detector = DWposeDetector(
70 | model_det="pretrained/DWPose/yolox_l.onnx",
71 | model_pose="pretrained/DWPose/dw-ll_ucoco_384.onnx",
72 | device=device)
73 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def nms(boxes, scores, nms_thr):
6 | """Single class NMS implemented in Numpy.
7 |
8 | Args:
9 | boxes (np.ndarray): shape=(N,4); N is number of boxes
10 | scores (np.ndarray): the score of bboxes
11 | nms_thr (float): the threshold in NMS
12 |
13 | Returns:
14 | List[int]: output bbox ids
15 | """
16 | x1 = boxes[:, 0]
17 | y1 = boxes[:, 1]
18 | x2 = boxes[:, 2]
19 | y2 = boxes[:, 3]
20 |
21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
22 | order = scores.argsort()[::-1]
23 |
24 | keep = []
25 | while order.size > 0:
26 | i = order[0]
27 | keep.append(i)
28 | xx1 = np.maximum(x1[i], x1[order[1:]])
29 | yy1 = np.maximum(y1[i], y1[order[1:]])
30 | xx2 = np.minimum(x2[i], x2[order[1:]])
31 | yy2 = np.minimum(y2[i], y2[order[1:]])
32 |
33 | w = np.maximum(0.0, xx2 - xx1 + 1)
34 | h = np.maximum(0.0, yy2 - yy1 + 1)
35 | inter = w * h
36 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
37 |
38 | inds = np.where(ovr <= nms_thr)[0]
39 | order = order[inds + 1]
40 |
41 | return keep
42 |
43 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
44 | """Multiclass NMS implemented in Numpy. Class-aware version.
45 |
46 | Args:
47 | boxes (np.ndarray): shape=(N,4); N is number of boxes
48 | scores (np.ndarray): the score of bboxes
49 | nms_thr (float): the threshold in NMS
50 | score_thr (float): the threshold of cls score
51 |
52 | Returns:
53 | np.ndarray: outputs bboxes coordinate
54 | """
55 | final_dets = []
56 | num_classes = scores.shape[1]
57 | for cls_ind in range(num_classes):
58 | cls_scores = scores[:, cls_ind]
59 | valid_score_mask = cls_scores > score_thr
60 | if valid_score_mask.sum() == 0:
61 | continue
62 | else:
63 | valid_scores = cls_scores[valid_score_mask]
64 | valid_boxes = boxes[valid_score_mask]
65 | keep = nms(valid_boxes, valid_scores, nms_thr)
66 | if len(keep) > 0:
67 | cls_inds = np.ones((len(keep), 1)) * cls_ind
68 | dets = np.concatenate(
69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
70 | )
71 | final_dets.append(dets)
72 | if len(final_dets) == 0:
73 | return None
74 | return np.concatenate(final_dets, 0)
75 |
76 | def demo_postprocess(outputs, img_size, p6=False):
77 | grids = []
78 | expanded_strides = []
79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
80 |
81 | hsizes = [img_size[0] // stride for stride in strides]
82 | wsizes = [img_size[1] // stride for stride in strides]
83 |
84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
87 | grids.append(grid)
88 | shape = grid.shape[:2]
89 | expanded_strides.append(np.full((*shape, 1), stride))
90 |
91 | grids = np.concatenate(grids, 1)
92 | expanded_strides = np.concatenate(expanded_strides, 1)
93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
95 |
96 | return outputs
97 |
98 | def preprocess(img, input_size, swap=(2, 0, 1)):
99 | if len(img.shape) == 3:
100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
101 | else:
102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
103 |
104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
105 | resized_img = cv2.resize(
106 | img,
107 | (int(img.shape[1] * r), int(img.shape[0] * r)),
108 | interpolation=cv2.INTER_LINEAR,
109 | ).astype(np.uint8)
110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
111 |
112 | padded_img = padded_img.transpose(swap)
113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
114 | return padded_img, r
115 |
116 | def inference_detector(session, oriImg):
117 | """run human detect
118 | """
119 | input_shape = (640,640)
120 | img, ratio = preprocess(oriImg, input_shape)
121 |
122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
123 | output = session.run(None, ort_inputs)
124 | predictions = demo_postprocess(output[0], input_shape)[0]
125 |
126 | boxes = predictions[:, :4]
127 | scores = predictions[:, 4:5] * predictions[:, 5:]
128 |
129 | boxes_xyxy = np.ones_like(boxes)
130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
134 | boxes_xyxy /= ratio
135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
136 | if dets is not None:
137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
138 | isscore = final_scores>0.3
139 | iscat = final_cls_inds == 0
140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
141 | final_boxes = final_boxes[isbbox]
142 | else:
143 | final_boxes = np.array([])
144 |
145 | return final_boxes
146 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/preprocess.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import decord
3 | import numpy as np
4 |
5 | from .util import draw_pose
6 | from .dwpose_detector import dwpose_detector as dwprocessor
7 |
8 |
9 | def get_video_pose(
10 | video_path: str,
11 | ref_image: np.ndarray,
12 | max_frame_num = None,
13 | sample_stride: int=1):
14 | """preprocess ref image pose and video pose
15 |
16 | Args:
17 | video_path (str): video pose path
18 | ref_image (np.ndarray): reference image
19 | sample_stride (int, optional): Defaults to 1.
20 |
21 | Returns:
22 | np.ndarray: sequence of video pose
23 | """
24 | # select ref-keypoint from reference pose for pose rescale
25 | ref_pose = dwprocessor(ref_image)
26 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
27 | ref_keypoint_id = [i for i in ref_keypoint_id \
28 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]
29 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
30 |
31 | height, width, _ = ref_image.shape
32 |
33 | # read input video
34 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
35 | # sample_stride *= max(1, int(vr.get_avg_fps() / 24))
36 | if max_frame_num is None:
37 | max_frame_num = len(vr)
38 | else:
39 | max_frame_num = min(len(vr), max_frame_num * sample_stride)
40 |
41 | frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
42 | detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")]
43 | dwprocessor.release_memory()
44 |
45 | detected_bodies = np.stack(
46 | [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
47 | ref_keypoint_id]
48 |     # compute linear-rescale params: fit a y-axis scale/offset (via polyfit) mapping the detected bodies onto the reference body; the x-axis scale is then derived from the aspect ratios and the x-offset from matching the means
49 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
50 | fh, fw, _ = vr[0].shape
51 | ax = ay / (fh / fw / height * width)
52 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
53 | a = np.array([ax, ay])
54 | b = np.array([bx, by])
55 | output_pose = []
56 | # pose rescale
57 | for detected_pose in detected_poses:
58 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
59 | detected_pose['faces'] = detected_pose['faces'] * a + b
60 | detected_pose['hands'] = detected_pose['hands'] * a + b
61 | im = draw_pose(detected_pose, height, width)
62 | output_pose.append(np.array(im))
63 | return np.stack(output_pose)
64 |
65 |
66 | def get_image_pose(ref_image):
67 | """process image pose
68 |
69 | Args:
70 | ref_image (np.ndarray): reference image pixel value
71 |
72 | Returns:
73 | np.ndarray: pose visual image in RGB-mode
74 | """
75 | height, width, _ = ref_image.shape
76 | ref_pose = dwprocessor(ref_image)
77 | pose_img = draw_pose(ref_pose, height, width)
78 | return np.array(pose_img)
79 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import cv2
5 |
6 |
7 | eps = 0.01
8 |
9 | def alpha_blend_color(color, alpha):
10 | """blend color according to point conf
11 | """
12 | return [int(c * alpha) for c in color]
13 |
14 | def draw_bodypose(canvas, candidate, subset, score):
15 | H, W, C = canvas.shape
16 | candidate = np.array(candidate)
17 | subset = np.array(subset)
18 |
19 | stickwidth = 4
20 |
21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
23 | [1, 16], [16, 18], [3, 17], [6, 18]]
24 |
25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28 |
29 | for i in range(17):
30 | for n in range(len(subset)):
31 | index = subset[n][np.array(limbSeq[i]) - 1]
32 | conf = score[n][np.array(limbSeq[i]) - 1]
33 | if conf[0] < 0.3 or conf[1] < 0.3:
34 | continue
35 | Y = candidate[index.astype(int), 0] * float(W)
36 | X = candidate[index.astype(int), 1] * float(H)
37 | mX = np.mean(X)
38 | mY = np.mean(Y)
39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43 |
44 | canvas = (canvas * 0.6).astype(np.uint8)
45 |
46 | for i in range(18):
47 | for n in range(len(subset)):
48 | index = int(subset[n][i])
49 | if index == -1:
50 | continue
51 | x, y = candidate[index][0:2]
52 | conf = score[n][i]
53 | x = int(x * W)
54 | y = int(y * H)
55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56 |
57 | return canvas
58 |
59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60 | H, W, C = canvas.shape
61 |
62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64 |
65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66 |
67 | for ie, e in enumerate(edges):
68 | x1, y1 = peaks[e[0]]
69 | x2, y2 = peaks[e[1]]
70 | x1 = int(x1 * W)
71 | y1 = int(y1 * H)
72 | x2 = int(x2 * W)
73 | y2 = int(y2 * H)
74 | score = int(scores[e[0]] * scores[e[1]] * 255)
75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76 | cv2.line(canvas, (x1, y1), (x2, y2),
77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78 |
79 |         for i, keypoint in enumerate(peaks):
80 |             x, y = keypoint
81 | x = int(x * W)
82 | y = int(y * H)
83 | score = int(scores[i] * 255)
84 | if x > eps and y > eps:
85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86 | return canvas
87 |
88 | def draw_facepose(canvas, all_lmks, all_scores):
89 | H, W, C = canvas.shape
90 | for lmks, scores in zip(all_lmks, all_scores):
91 | for lmk, score in zip(lmks, scores):
92 | x, y = lmk
93 | x = int(x * W)
94 | y = int(y * H)
95 | conf = int(score * 255)
96 | if x > eps and y > eps:
97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98 | return canvas
99 |
100 | def draw_pose(pose, H, W, ref_w=2160):
101 | """vis dwpose outputs
102 |
103 | Args:
104 | pose (List): DWposeDetector outputs in dwpose_detector.py
105 | H (int): height
106 | W (int): width
107 | ref_w (int, optional) Defaults to 2160.
108 |
109 | Returns:
110 | np.ndarray: image pixel value in RGB mode
111 | """
112 | bodies = pose['bodies']
113 | faces = pose['faces']
114 | hands = pose['hands']
115 | candidate = bodies['candidate']
116 | subset = bodies['subset']
117 |
118 | sz = min(H, W)
119 | sr = (ref_w / sz) if sz != ref_w else 1
120 |
121 | ########################################## create zero canvas ##################################################
122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
123 |
124 | ########################################### draw body pose #####################################################
125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
126 |
127 | ########################################### draw hand pose #####################################################
128 | canvas = draw_handpose(canvas, hands, pose['hands_score'])
129 |
130 | ########################################### draw face pose #####################################################
131 | canvas = draw_facepose(canvas, faces, pose['faces_score'])
132 |
133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
134 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import onnxruntime as ort
3 |
4 | from .onnxdet import inference_detector
5 | from .onnxpose import inference_pose
6 |
7 |
8 | class Wholebody:
9 | """detect human pose by dwpose
10 | """
11 | def __init__(self, model_det, model_pose, device="cpu"):
12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
13 | provider_options = None if device == 'cpu' else [{'device_id': 0}]
14 |
15 | self.session_det = ort.InferenceSession(
16 | path_or_bytes=model_det, providers=providers, provider_options=provider_options
17 | )
18 | self.session_pose = ort.InferenceSession(
19 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options
20 | )
21 |
22 | def __call__(self, oriImg):
23 | """call to process dwpose-detect
24 |
25 | Args:
26 | oriImg (np.ndarray): detected image
27 |
28 | """
29 | det_result = inference_detector(self.session_det, oriImg)
30 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
31 |
32 | keypoints_info = np.concatenate(
33 | (keypoints, scores[..., None]), axis=-1)
34 | # compute neck joint
35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36 | # neck score when visualizing pred
37 | neck[:, 2:4] = np.logical_and(
38 | keypoints_info[:, 5, 2:4] > 0.3,
39 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
40 | new_keypoints_info = np.insert(
41 | keypoints_info, 17, neck, axis=1)
42 | mmpose_idx = [
43 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
44 | ]
45 | openpose_idx = [
46 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
47 | ]
48 | new_keypoints_info[:, openpose_idx] = \
49 | new_keypoints_info[:, mmpose_idx]
50 | keypoints_info = new_keypoints_info
51 |
52 | keypoints, scores = keypoints_info[
53 | ..., :2], keypoints_info[..., 2]
54 |
55 | return keypoints, scores
56 |
57 |
58 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/demos/01-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/01-1.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/demos/02-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/02-1.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/demos/03-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/03-1.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/demos/04-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/04-1.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/facefusion/facefusion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/facefusion/facefusion.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/ref_imgs/01.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/01.jpeg
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/ref_imgs/02.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/02.jpeg
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/ref_imgs/03.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/03.jpeg
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/ref_imgs/04.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/04.jpeg
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/video/01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/video/01.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/examples/video/02.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/video/02.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/models/controlnext_vid_svd.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D
10 |
11 |
12 | class ControlNeXtSDVModel(ModelMixin, ConfigMixin):
13 | _supports_gradient_checkpointing = True
14 |
15 | @register_to_config
16 | def __init__(
17 | self,
18 | time_embed_dim = 256,
19 | in_channels = [128, 128],
20 | out_channels = [128, 256],
21 | groups = [4, 8]
22 | ):
23 | super().__init__()
24 |
25 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
26 | self.time_embedding = TimestepEmbedding(128, time_embed_dim)
27 | self.embedding = nn.Sequential(
28 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
29 | nn.GroupNorm(2, 64),
30 | nn.ReLU(),
31 | nn.Conv2d(64, 64, kernel_size=3),
32 | nn.GroupNorm(2, 64),
33 | nn.ReLU(),
34 | nn.Conv2d(64, 128, kernel_size=3),
35 | nn.GroupNorm(2, 128),
36 | nn.ReLU(),
37 | )
38 |
39 | self.down_res = nn.ModuleList()
40 | self.down_sample = nn.ModuleList()
41 | for i in range(len(in_channels)):
42 | self.down_res.append(
43 | ResnetBlock2D(
44 | in_channels=in_channels[i],
45 | out_channels=out_channels[i],
46 | temb_channels=time_embed_dim,
47 | groups=groups[i]
48 | ),
49 | )
50 | self.down_sample.append(
51 | Downsample2D(
52 | out_channels[i],
53 | use_conv=True,
54 | out_channels=out_channels[i],
55 | padding=1,
56 | name="op",
57 | )
58 | )
59 |
60 | self.mid_convs = nn.ModuleList()
61 | self.mid_convs.append(nn.Sequential(
62 | nn.Conv2d(
63 | in_channels=out_channels[-1],
64 | out_channels=out_channels[-1],
65 | kernel_size=3,
66 | stride=1,
67 | padding=1
68 | ),
69 | nn.ReLU(),
70 | nn.GroupNorm(8, out_channels[-1]),
71 | nn.Conv2d(
72 | in_channels=out_channels[-1],
73 | out_channels=out_channels[-1],
74 | kernel_size=3,
75 | stride=1,
76 | padding=1
77 | ),
78 | nn.GroupNorm(8, out_channels[-1]),
79 | ))
80 | self.mid_convs.append(
81 | nn.Conv2d(
82 | in_channels=out_channels[-1],
83 | out_channels=320,
84 | kernel_size=1,
85 | stride=1,
86 | ))
87 |
88 | self.scale = 1.
89 |
90 | def _set_gradient_checkpointing(self, module, value=False):
91 | if hasattr(module, "gradient_checkpointing"):
92 | module.gradient_checkpointing = value
93 |
94 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
95 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
96 | """
97 | Sets the attention processor to use [feed forward
98 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
99 |
100 | Parameters:
101 | chunk_size (`int`, *optional*):
102 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
103 | over each tensor of dim=`dim`.
104 | dim (`int`, *optional*, defaults to `0`):
105 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
106 | or dim=1 (sequence length).
107 | """
108 | if dim not in [0, 1]:
109 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
110 |
111 | # By default chunk size is 1
112 | chunk_size = chunk_size or 1
113 |
114 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
115 | if hasattr(module, "set_chunk_feed_forward"):
116 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
117 |
118 | for child in module.children():
119 | fn_recursive_feed_forward(child, chunk_size, dim)
120 |
121 | for module in self.children():
122 | fn_recursive_feed_forward(module, chunk_size, dim)
123 |
124 | def forward(
125 | self,
126 | sample: torch.FloatTensor,
127 | timestep: Union[torch.Tensor, float, int],
128 | ):
129 |
130 | timesteps = timestep
131 | if not torch.is_tensor(timesteps):
132 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
133 | # This would be a good case for the `match` statement (Python 3.10+)
134 | is_mps = sample.device.type == "mps"
135 | if isinstance(timestep, float):
136 | dtype = torch.float32 if is_mps else torch.float64
137 | else:
138 | dtype = torch.int32 if is_mps else torch.int64
139 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
140 | elif len(timesteps.shape) == 0:
141 | timesteps = timesteps[None].to(sample.device)
142 |
143 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
144 | batch_size, num_frames = sample.shape[:2]
145 | timesteps = timesteps.expand(batch_size)
146 |
147 | t_emb = self.time_proj(timesteps)
148 |
149 | # `Timesteps` does not contain any weights and will always return f32 tensors
150 | # but time_embedding might actually be running in fp16. so we need to cast here.
151 | # there might be better ways to encapsulate this.
152 | t_emb = t_emb.to(dtype=sample.dtype)
153 |
154 | emb_batch = self.time_embedding(t_emb)
155 |
156 | # Flatten the batch and frames dimensions
157 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
158 | sample = sample.flatten(0, 1)
159 | # Repeat the embeddings num_video_frames times
160 | # emb: [batch, channels] -> [batch * frames, channels]
161 | emb = emb_batch.repeat_interleave(num_frames, dim=0)
162 |
163 | sample = self.embedding(sample)
164 |
165 | for res, downsample in zip(self.down_res, self.down_sample):
166 | sample = res(sample, emb)
167 | sample = downsample(sample, emb)
168 |
169 | sample = self.mid_convs[0](sample) + sample
170 | sample = self.mid_convs[1](sample)
171 |
172 | return {
173 | 'output': sample,
174 | 'scale': self.scale,
175 | }
176 |
177 |
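# Minimal usage sketch (arbitrary shapes and timestep): the conditioning input is a batch of
# pose frames shaped [batch, frames, channels, height, width]; the forward pass returns a dict
# with the control feature map ('output') and its 'scale'.
if __name__ == "__main__":
    model = ControlNeXtSDVModel()
    pose = torch.randn(1, 2, 3, 256, 256)           # [batch, frames, C, H, W]
    out = model(pose, timestep=torch.tensor([500]))
    print(out["output"].shape, out["scale"])        # e.g. torch.Size([2, 320, 31, 31]) and 1.0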
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/run_controlnext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt
6 | from models.controlnext_vid_svd import ControlNeXtSDVModel
7 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel
8 | from transformers import CLIPVisionModelWithProjection
9 | import re
10 | from diffusers import AutoencoderKLTemporalDecoder
11 | from moviepy.editor import ImageSequenceClip
12 | from decord import VideoReader
13 | import argparse
14 | from safetensors.torch import load_file
15 | from utils.pre_process import preprocess
16 |
17 |
18 | def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"):
19 | clip = ImageSequenceClip(samples, fps=fps)
20 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
21 | ffmpeg_params=["-crf", "18", "-preset", "slow"])
22 |
23 | def save_vid_side_by_side(batch_output, validation_control_images, output_folder, fps):
24 | # Helper function to convert tensors to PIL images and save as GIF
25 | flattened_batch_output = [img for sublist in batch_output for img in sublist]
26 | video_path = output_folder+'/test_1.mp4'
27 | final_images = []
28 | outputs = []
29 | # Helper function to concatenate images horizontally
30 | def get_concat_h(im1, im2):
31 | dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
32 | dst.paste(im1, (0, 0))
33 | dst.paste(im2, (im1.width, 0))
34 | return dst
35 | for image_list in zip(validation_control_images, flattened_batch_output):
36 | predict_img = image_list[1].resize(image_list[0].size)
37 | result = get_concat_h(image_list[0], predict_img)
38 | final_images.append(np.array(result))
39 | outputs.append(np.array(predict_img))
40 | write_mp4(video_path, final_images, fps=fps)
41 |
42 | output_path = output_folder + "/output.mp4"
43 | write_mp4(output_path, outputs, fps=fps)
44 |
45 |
46 | def load_images_from_folder_to_pil(folder):
47 | images = []
48 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
49 |
50 | # Function to extract frame number from the filename
51 | def frame_number(filename):
52 | # First, try the pattern 'frame_x_7fps'
53 | new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
54 | if new_pattern_match:
55 | return int(new_pattern_match.group(1))
56 | # If the new pattern is not found, use the original digit extraction method
57 | matches = re.findall(r'\d+', filename)
58 | if matches:
59 | if matches[-1] == '0000' and len(matches) > 1:
60 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
61 | return int(matches[-1]) # Otherwise, return the last sequence
62 | return float('inf') # Return 'inf'
63 |
64 | # Sorting files based on frame number
65 | sorted_files = sorted(os.listdir(folder), key=frame_number)
66 | # Load images in sorted order
67 | for filename in sorted_files:
68 | ext = os.path.splitext(filename)[1].lower()
69 | if ext in valid_extensions:
70 | img = Image.open(os.path.join(folder, filename)).convert('RGB')
71 | images.append(img)
72 |
73 | return images
74 |
75 |
76 | def load_images_from_video_to_pil(video_path):
77 | images = []
78 |
79 | vr = VideoReader(video_path)
80 | length = len(vr)
81 |
82 | for idx in range(length):
83 | frame = vr[idx].asnumpy()
84 | images.append(Image.fromarray(frame))
85 | return images
86 |
87 |
88 | def parse_args():
89 | parser = argparse.ArgumentParser(
90 | description="Inference script for ControlNeXt-SVD-v2 pose-guided video generation."
91 | )
92 |
93 | parser.add_argument(
94 | "--pretrained_model_name_or_path",
95 | type=str,
96 | default=None,
97 | required=True
98 | )
99 |
100 | parser.add_argument(
101 | "--validation_control_images_folder",
102 | type=str,
103 | default=None,
104 | required=False,
105 | )
106 |
107 | parser.add_argument(
108 | "--validation_control_video_path",
109 | type=str,
110 | default=None,
111 | required=False,
112 | )
113 |
114 | parser.add_argument(
115 | "--output_dir",
116 | type=str,
117 | default=None,
118 | required=True
119 | )
120 |
121 | parser.add_argument(
122 | "--height",
123 | type=int,
124 | default=768,
125 | required=False
126 | )
127 |
128 | parser.add_argument(
129 | "--width",
130 | type=int,
131 | default=512,
132 | required=False
133 | )
134 |
135 | parser.add_argument(
136 | "--guidance_scale",
137 | type=float,
138 | default=2.,
139 | required=False
140 | )
141 |
142 | parser.add_argument(
143 | "--num_inference_steps",
144 | type=int,
145 | default=25,
146 | required=False
147 | )
148 |
149 |
150 | parser.add_argument(
151 | "--controlnext_path",
152 | type=str,
153 | default=None,
154 | required=True
155 | )
156 |
157 | parser.add_argument(
158 | "--unet_path",
159 | type=str,
160 | default=None,
161 | required=True
162 | )
163 |
164 | parser.add_argument(
165 | "--max_frame_num",
166 | type=int,
167 | default=50,
168 | required=False
169 | )
170 |
171 | parser.add_argument(
172 | "--ref_image_path",
173 | type=str,
174 | default=None,
175 | required=True
176 | )
177 |
178 | parser.add_argument(
179 | "--batch_frames",
180 | type=int,
181 | default=14,
182 | required=False
183 | )
184 |
185 | parser.add_argument(
186 | "--overlap",
187 | type=int,
188 | default=4,
189 | required=False
190 | )
191 |
192 | parser.add_argument(
193 | "--sample_stride",
194 | type=int,
195 | default=2,
196 | required=False
197 | )
198 |
199 | args = parser.parse_args()
200 | return args
201 |
202 |
203 | def load_tensor(tensor_path):
204 | if os.path.splitext(tensor_path)[1] == '.bin':
205 | return torch.load(tensor_path)
206 | elif os.path.splitext(tensor_path)[1] == ".safetensors":
207 | return load_file(tensor_path)
208 | else:
209 | print("unsupported tensor format; expected .bin or .safetensors")
210 | os._exit(1)
211 |
212 |
213 | # Main script
214 | if __name__ == "__main__":
215 | args = parse_args()
216 |
217 | assert (args.validation_control_images_folder is None) ^ (args.validation_control_video_path is None), "exactly one of [validation_control_images_folder, validation_control_video_path] must be given"
218 |
219 | unet = UNetSpatioTemporalConditionControlNeXtModel.from_pretrained(
220 | args.pretrained_model_name_or_path,
221 | subfolder="unet",
222 | low_cpu_mem_usage=True,
223 | )
224 | controlnext = ControlNeXtSDVModel()
225 | controlnext.load_state_dict(load_tensor(args.controlnext_path))
226 | unet.load_state_dict(load_tensor(args.unet_path), strict=False)
227 |
228 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(
229 | args.pretrained_model_name_or_path, subfolder="image_encoder")
230 | vae = AutoencoderKLTemporalDecoder.from_pretrained(
231 | args.pretrained_model_name_or_path, subfolder="vae")
232 |
233 | pipeline = StableVideoDiffusionPipelineControlNeXt.from_pretrained(
234 | args.pretrained_model_name_or_path,
235 | controlnext=controlnext,
236 | unet=unet,
237 | vae=vae,
238 | image_encoder=image_encoder)
239 | # pipeline.to(dtype=torch.float16)
240 | pipeline.enable_model_cpu_offload()
241 |
242 | os.makedirs(args.output_dir, exist_ok=True)
243 |
244 | # Inference and saving loop
245 | # ref_image = Image.open(args.ref_image_path).convert('RGB')
246 | # ref_image = ref_image.resize((args.width, args.height))
247 | # validation_control_images = [img.resize((args.width, args.height)) for img in validation_control_images]
248 |
249 | validation_control_images, ref_image = preprocess(args.validation_control_video_path, args.ref_image_path, width=args.width, height=args.height, max_frame_num=args.max_frame_num, sample_stride=args.sample_stride)
250 |
251 |
252 | final_result = []
253 | frames = args.batch_frames
254 | num_frames = min(args.max_frame_num, len(validation_control_images))
255 |
256 | for i in range(num_frames):
257 | validation_control_images[i] = Image.fromarray(np.array(validation_control_images[i]))
258 |
259 | video_frames = pipeline(
260 | ref_image,
261 | validation_control_images[:num_frames],
262 | decode_chunk_size=2,
263 | num_frames=num_frames,
264 | motion_bucket_id=127.0,
265 | fps=7,
266 | controlnext_cond_scale=1.0,
267 | width=args.width,
268 | height=args.height,
269 | min_guidance_scale=args.guidance_scale,
270 | max_guidance_scale=args.guidance_scale,
271 | frames_per_batch=frames,
272 | num_inference_steps=args.num_inference_steps,
273 | overlap=args.overlap).frames[0]
274 | final_result.append(video_frames)
275 |
276 | fps = VideoReader(args.validation_control_video_path).get_avg_fps() // args.sample_stride
277 |
278 | save_vid_side_by_side(
279 | final_result,
280 | validation_control_images[:num_frames],
281 | args.output_dir,
282 | fps=fps)
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/script.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \
2 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
3 | --output_dir outputs \
4 | --max_frame_num 240 \
5 | --guidance_scale 3 \
6 | --batch_frames 24 \
7 | --sample_stride 2 \
8 | --overlap 6 \
9 | --height 1024 \
10 | --width 576 \
11 | --controlnext_path pretrained/controlnet.bin \
12 | --unet_path pretrained/unet.bin \
13 | --validation_control_video_path examples/video/02.mp4 \
14 | --ref_image_path examples/ref_imgs/01.jpeg
15 |
16 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD-v2/utils/pre_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import math
5 | from omegaconf import OmegaConf
6 | from datetime import datetime
7 | from pathlib import Path
8 | from PIL import Image
9 | import numpy as np
10 | import torch.jit
11 | from torchvision.datasets.folder import pil_loader
12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
13 | from torchvision.transforms.functional import to_pil_image
14 | from dwpose.preprocess import get_image_pose, get_video_pose
15 |
16 | ASPECT_RATIO = 9 / 16
17 |
18 | def preprocess(video_path, image_path, width=576, height=1024, sample_stride=2, max_frame_num=None):
19 | """preprocess ref image pose and video pose
20 |
21 | Args:
22 | video_path (str): input video pose path
23 | image_path (str): reference image path
24 | resolution (int, optional): Defaults to 576.
25 | sample_stride (int, optional): Defaults to 2.
26 | """
27 | image_pixels = pil_loader(image_path)
28 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
29 | h, w = image_pixels.shape[-2:]
30 | ############################ compute target h/w according to original aspect ratio ###############################
31 | # if h>w:
32 | # w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
33 | # else:
34 | # w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
35 | w_target, h_target = width, height
36 | h_w_ratio = float(h) / float(w)
37 | if h_w_ratio < h_target / w_target:
38 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
39 | else:
40 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
41 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
42 | image_pixels = center_crop(image_pixels, [h_target, w_target])
43 | image_pixels = image_pixels.permute((1, 2, 0)).numpy()
44 | ##################################### get image&video pose value #################################################
45 | image_pose = get_image_pose(image_pixels)
46 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, max_frame_num=max_frame_num)
47 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
48 | # image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
49 | image_pixels = Image.fromarray(image_pixels)
50 | pose_pixels = [Image.fromarray(p.transpose((1,2,0))) for p in pose_pixels]
51 | # return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1
52 | return pose_pixels, image_pixels
53 |
54 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD/README.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-SVD
2 |
3 | This is our implementation of ControlNeXt based on [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). It can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) with a more concise and efficient architecture.
4 |
5 | Compared to image generation, video generation poses significantly greater challenges. While directly training the generation model with our method is feasible, we also employ various engineering strategies to enhance performance, even though these strategies are independent of the academic algorithm itself.
6 |
7 |
8 | > Please refer to [Examples](#examples) for further intuitive details.\
9 | > Please refer to [Base model](#base-model) for more details about the base model we use. \
10 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\
11 | > Please refer to [Advanced Performance](#advanced-performance) for more details on achieving better performance.\
12 | > Please refer to [Limitations](#limitations) for more details about the limitations of current work.
13 |
14 | # Examples
15 | If you can't load the videos, you can also directly download them from [here](outputs).
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | # Base Model
26 |
27 | The base model's generation capability significantly influences video generation. Initially, we train the generation model using our method, which is based on [Stable Video Diffusion XT-1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). However, the original SVD model exhibits weaknesses in generating human features, particularly in the generation of hands and faces.
28 |
29 | Therefore, we initially conduct continuous pretraining of [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) on our curated collection of human-related videos to improve its ability to generate human features. Subsequently, we fine-tune it for specific downstream tasks, i.e., generating dance videos guided by pose sequences.
30 |
31 | In this project, we release all our models, including the base model and the fine-tuned model. You can download them from:
32 | * [Finetuned Model](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/finetune): We fine-tune our own trained base model for the downstream task using our proposed method, incorporating only `50M` learnable parameters. For your convenience, we directly merge the pretrained base model with the fine-tuned parameters, and you can download this consolidated model.
33 | * [Continuously Pretrained Model](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/pretrained): We continuously pretrain the base model using our collected human-related data, based on [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). This approach improves performance in generating human features, particularly for hands and faces. However, due to the complexity of human motion, it still faces challenges similar to SVD in preferring to generate static videos. Nonetheless, it excels in downstream tasks. We encourage more participation to further enhance the base model for generating human-related videos.
34 | * [Finetuned Parameters](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/learned_params): The parameters learned during fine-tuning. Alternatively, you can directly download the `Finetuned Model`.
35 |
36 |
37 | # Inference
38 |
39 | 1. Clone our repository
40 | 2. `cd ControlNeXt-SVD`
41 | 3. Download the pretrained weights into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/finetune). (For more details, please refer to [Base Model](#base-model).)
42 | 4. Run the script
43 |
44 | ```bash
45 | python run_controlnext.py \
46 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
47 | --validation_control_video_path examples/pose/pose.mp4 \
48 | --output_dir outputs/tiktok \
49 | --controlnext_path pretrained/controlnet.bin \
50 | --unet_path pretrained/unet_fp16.bin \
51 | --ref_image_path examples/ref_imgs/tiktok.png
52 | ```
53 |
54 | > --pretrained_model_name_or_path : pretrained base model; we pretrain and fine-tune our models based on [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1)\
55 | > --controlnext_path : path to the ControlNeXt weights (a lightweight module) \
56 | > --unet_path : path to the UNet weights
57 |
58 | 5. Face Enhancement (optional, recommended when the generated faces look poor)
59 |
60 | > Currently, the model is not specifically trained for IP consistency, as there are already many mature tools available. Additionally, alternatives like Animate Anyone also adopt such post-processing techniques.
61 |
62 | a. Clone [Face Fusion](https://github.com/facefusion/facefusion): \
63 | ```git clone https://github.com/facefusion/facefusion```
64 |
65 | b. Enter the directory:\
66 | ```cd facefusion```
67 |
68 | c. Install FaceFusion (we recommend creating a new conda virtual environment to avoid conflicts):\
69 | ```python install.py```
70 |
71 | d. Run the command:
72 | ```
73 | python run.py \
74 | -s ../outputs/collected/demo.jpg \
75 | -t ../outputs/collected/demo.mp4 \
76 | -o ../outputs/collected/out.mp4 \
77 | --headless \
78 | --execution-providers cuda \
79 | --face-selector-mode one
80 | ```
81 |
82 | > -s: the reference image \
83 | > -t: the path to the original video\
84 | > -o: the path to store the refined video\
85 | > --headless: run without the GUI\
86 | > --execution-providers cuda: use CUDA for acceleration (if available; otherwise the CPU is usually sufficient)
87 |
88 | # Advanced Performance
89 | In this section, we delve into additional details and share our own experience with enhancing video generation. These factors are algorithm-independent, yet crucial for achieving superior results. Many closely related works incorporate these strategies.
90 |
91 | ### Reference Image
92 |
93 | It is crucial to ensure that the reference image is clear and easy to understand, and especially that the face in the reference image is aligned with the pose.
94 |
95 |
96 | ### Face Enhancement
97 |
98 | Most related works utilize face enhancement as part of the post-processing. This is especially relevant when generating videos based on images of unfamiliar individuals, such as friends, who were not included in the base model's pretraining and are therefore unseen, out-of-distribution (OOD) data.
99 |
100 | We recommend [FaceFusion](https://github.com/facefusion/facefusion) for this post-processing step. Please let us know if you have a better solution.
101 |
102 | Please refer to [FaceFusion](https://github.com/facefusion/facefusion) for more details.
103 |
104 |
105 |
106 | 
107 |
108 |
109 | ### Continuously Finetune
110 |
111 | To significantly enhance performance on a specific pose sequence, you can continuously fine-tune the model for just a few hundred steps.
112 |
113 | We will release the related fine-tuning code later.
114 |
115 | ### Pose Generation
116 |
117 | We adopt [DWPose](https://github.com/IDEA-Research/DWPose) for the pose generation.
118 |
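The released pose video (`examples/pose/pose.mp4`) is already rendered, so this folder's inference script simply reads it frame by frame. If you want to render poses for your own driving video, the v2 code in this repository wraps DWPose behind a single helper; the sketch below mirrors the `preprocess` call used by `ControlNeXt-SVD-v2/run_controlnext.py` and is meant as a hedged illustration rather than part of this (v1) pipeline. The example paths are the v2 sample assets.

```python
# Hedged sketch (v2 codebase): render DWPose conditioning frames for a driving
# video, aligned to a reference image. Run from ControlNeXt-SVD-v2/.
from utils.pre_process import preprocess  # wraps dwpose get_image_pose / get_video_pose

pose_frames, ref_image = preprocess(
    "examples/video/02.mp4",      # driving video used as the pose source
    "examples/ref_imgs/01.jpeg",  # reference image (cropped/resized internally)
    width=576, height=1024,
    sample_stride=2, max_frame_num=240,
)
pose_frames[0].save("pose_0000.png")  # inspect the first rendered pose frame
```

Note that the first returned frame is the pose of the reference image itself; the remaining frames follow the driving video.
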
119 | # Limitations
120 |
121 | ## IP Consistency
122 |
123 | We did not prioritize maintaining IP consistency during the development of the generation model and now rely on a helper model for face enhancement.
124 |
125 | However, additional training can be implemented to ensure IP consistency moving forward.
126 |
127 | This also leaves a possible direction for further improvement.
128 |
129 | ## Base model
130 |
131 | The base model plays a crucial role in generating human features, particularly hands and faces. We encourage collaboration to improve the base model for enhanced human-related video generation.
132 |
133 | # TODO
134 |
135 | * Training and fine-tuning code
136 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD/examples/facefusion/facefusion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/facefusion/facefusion.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SVD/examples/pose/pose.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/pose/pose.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/examples/ref_imgs/spiderman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/ref_imgs/spiderman.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SVD/examples/ref_imgs/tiktok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/ref_imgs/tiktok.png
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/chair/chair.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/chair/chair.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/collected/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/demo.jpg
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/collected/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/demo.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/collected/out2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/out2.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/spiderman/spiderman.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/spiderman/spiderman.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/star/star.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/star/star.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/outputs/tiktok/tiktok.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/tiktok/tiktok.mp4
--------------------------------------------------------------------------------
/ControlNeXt-SVD/run_controlnext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt
6 | from models.controlnext_vid_svd import ControlNeXtSDVModel
7 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel
8 | from transformers import CLIPVisionModelWithProjection
9 | import re
10 | from diffusers import AutoencoderKLTemporalDecoder
11 | from moviepy.editor import ImageSequenceClip
12 | from decord import VideoReader
13 | import argparse
14 | from safetensors.torch import load_file
15 |
16 |
17 | def write_mp4(video_path, samples, fps=14):
18 | clip = ImageSequenceClip(samples, fps=fps)
19 | clip.write_videofile(video_path, audio_codec="aac")
20 |
21 | def save_vid_side_by_side(batch_output, validation_control_images, output_folder, fps):
22 | # Save a side-by-side comparison video and the output-only video as MP4 files
23 | flattened_batch_output = [img for sublist in batch_output for img in sublist]
24 | video_path = output_folder+'/test_1.mp4'
25 | final_images = []
26 | outputs = []
27 | # Helper function to concatenate images horizontally
28 | def get_concat_h(im1, im2):
29 | dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
30 | dst.paste(im1, (0, 0))
31 | dst.paste(im2, (im1.width, 0))
32 | return dst
33 | for image_list in zip(validation_control_images, flattened_batch_output):
34 | predict_img = image_list[1].resize(image_list[0].size)
35 | result = get_concat_h(image_list[0], predict_img)
36 | final_images.append(np.array(result))
37 | outputs.append(np.array(predict_img))
38 | write_mp4(video_path, final_images, fps=fps)
39 |
40 | output_path = output_folder + "/output.mp4"
41 | write_mp4(output_path, outputs, fps=fps)
42 |
43 |
44 | def load_images_from_folder_to_pil(folder):
45 | images = []
46 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
47 |
48 | # Function to extract frame number from the filename
49 | def frame_number(filename):
50 | # First, try the pattern 'frame_x_7fps'
51 | new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
52 | if new_pattern_match:
53 | return int(new_pattern_match.group(1))
54 | # If the new pattern is not found, use the original digit extraction method
55 | matches = re.findall(r'\d+', filename)
56 | if matches:
57 | if matches[-1] == '0000' and len(matches) > 1:
58 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
59 | return int(matches[-1]) # Otherwise, return the last sequence
60 | return float('inf') # Return 'inf'
61 |
62 | # Sorting files based on frame number
63 | sorted_files = sorted(os.listdir(folder), key=frame_number)
64 | # Load images in sorted order
65 | for filename in sorted_files:
66 | ext = os.path.splitext(filename)[1].lower()
67 | if ext in valid_extensions:
68 | img = Image.open(os.path.join(folder, filename)).convert('RGB')
69 | images.append(img)
70 |
71 | return images
72 |
73 |
74 | def load_images_from_video_to_pil(video_path):
75 | images = []
76 |
77 | vr = VideoReader(video_path)
78 | length = len(vr)
79 |
80 | for idx in range(length):
81 | frame = vr[idx].asnumpy()
82 | images.append(Image.fromarray(frame))
83 | return images
84 |
85 |
86 | def parse_args():
87 | parser = argparse.ArgumentParser(
88 | description="Inference script for ControlNeXt-SVD pose-guided video generation."
89 | )
90 |
91 | parser.add_argument(
92 | "--pretrained_model_name_or_path",
93 | type=str,
94 | default=None,
95 | required=True
96 | )
97 |
98 | parser.add_argument(
99 | "--validation_control_images_folder",
100 | type=str,
101 | default=None,
102 | required=False,
103 | )
104 |
105 | parser.add_argument(
106 | "--validation_control_video_path",
107 | type=str,
108 | default=None,
109 | required=False,
110 | )
111 |
112 | parser.add_argument(
113 | "--output_dir",
114 | type=str,
115 | default=None,
116 | required=True
117 | )
118 |
119 | parser.add_argument(
120 | "--height",
121 | type=int,
122 | default=768,
123 | required=False
124 | )
125 |
126 | parser.add_argument(
127 | "--width",
128 | type=int,
129 | default=512,
130 | required=False
131 | )
132 |
133 | parser.add_argument(
134 | "--guidance_scale",
135 | type=float,
136 | default=3.5,
137 | required=False
138 | )
139 |
140 | parser.add_argument(
141 | "--num_inference_steps",
142 | type=int,
143 | default=25,
144 | required=False
145 | )
146 |
147 | parser.add_argument(
148 | "--fps",
149 | type=int,
150 | default=14,
151 | required=False
152 | )
153 |
154 | parser.add_argument(
155 | "--controlnext_path",
156 | type=str,
157 | default=None,
158 | required=True
159 | )
160 |
161 | parser.add_argument(
162 | "--unet_path",
163 | type=str,
164 | default=None,
165 | required=True
166 | )
167 |
168 | parser.add_argument(
169 | "--ref_image_path",
170 | type=str,
171 | default=None,
172 | required=True
173 | )
174 |
175 | args = parser.parse_args()
176 | return args
177 |
178 |
179 | def load_tensor(tensor_path):
180 | if os.path.splitext(tensor_path)[1] == '.bin':
181 | return torch.load(tensor_path)
182 | elif os.path.splitext(tensor_path)[1] == ".safetensors":
183 | return load_file(tensor_path)
184 | else:
185 | print("unsupported tensor format; expected .bin or .safetensors")
186 | os._exit(1)
187 |
188 |
189 | # Main script
190 | if __name__ == "__main__":
191 | args = parse_args()
192 |
193 | assert (args.validation_control_images_folder is None) ^ (args.validation_control_video_path is None), "exactly one of [validation_control_images_folder, validation_control_video_path] must be given"
194 | if args.validation_control_images_folder is not None:
195 | validation_control_images = load_images_from_folder_to_pil(args.validation_control_images_folder)
196 | else:
197 | validation_control_images = load_images_from_video_to_pil(args.validation_control_video_path)
198 |
199 | unet = UNetSpatioTemporalConditionControlNeXtModel.from_pretrained(
200 | args.pretrained_model_name_or_path,
201 | subfolder="unet",
202 | low_cpu_mem_usage=True,
203 | variant="fp16",
204 | )
205 | controlnext = ControlNeXtSDVModel()
206 | controlnext.load_state_dict(load_tensor(args.controlnext_path))
207 | unet.load_state_dict(load_tensor(args.unet_path), strict=False)
208 |
209 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(
210 | args.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16")
211 | vae = AutoencoderKLTemporalDecoder.from_pretrained(
212 | args.pretrained_model_name_or_path, subfolder="vae", variant="fp16")
213 |
214 | pipeline = StableVideoDiffusionPipelineControlNeXt.from_pretrained(
215 | args.pretrained_model_name_or_path,
216 | controlnext=controlnext,
217 | unet=unet,
218 | vae=vae,
219 | image_encoder=image_encoder)
220 | pipeline.to(dtype=torch.float16)
221 | pipeline.enable_model_cpu_offload()
222 |
223 | os.makedirs(args.output_dir, exist_ok=True)
224 |
225 | # Inference and saving loop
226 | final_result = []
227 | ref_image = Image.open(args.ref_image_path).convert('RGB')
228 | frames = 14
229 | num_frames = len(validation_control_images)
230 |
231 | video_frames = pipeline(
232 | ref_image,
233 | validation_control_images[:num_frames],
234 | decode_chunk_size=4,
235 | num_frames=num_frames,
236 | motion_bucket_id=127.0,
237 | fps=7,
238 | controlnext_cond_scale=1.0,
239 | width=args.width,
240 | height=args.height,
241 | min_guidance_scale=args.guidance_scale,
242 | max_guidance_scale=args.guidance_scale,
243 | frames_per_batch=frames,
244 | num_inference_steps=args.num_inference_steps,
245 | overlap=4).frames[0]
246 | final_result.append(video_frames)
247 |
248 | save_vid_side_by_side(
249 | final_result,
250 | validation_control_images[:num_frames],
251 | args.output_dir,
252 | fps=args.fps)
253 |
--------------------------------------------------------------------------------
/ControlNeXt-SVD/script.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | python run_controlnext.py \
4 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
5 | --validation_control_video_path examples/pose/pose.mp4 \
6 | --output_dir outputs/tiktok \
7 | --controlnext_path pretrained/controlnet.bin \
8 | --unet_path pretrained/unet_fp16.bin \
9 | --ref_image_path examples/ref_imgs/tiktok.png
10 |
11 |
12 | python run_controlnext.py \
13 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \
14 | --validation_control_video_path examples/pose/pose.mp4 \
15 | --output_dir outputs/spiderman \
16 | --controlnext_path pretrained/controlnet.bin \
17 | --unet_path pretrained/unet_fp16.bin \
18 | --ref_image_path examples/ref_imgs/spiderman.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # 🌀 ControlNeXt
3 |
4 |
5 |
6 | ## [📝 Project Page](https://pbihao.github.io/projects/controlnext/index.html) | [📚 Paper](https://arxiv.org/abs/2408.06070) | [🗂️ Demo (SDXL)](https://huggingface.co/spaces/Eugeoter/ControlNeXt)
7 |
8 |
9 | **ControlNeXt** is our official implementation for controllable generation, supporting both images and videos while incorporating diverse forms of control information. In this project, we propose a new method that reduces trainable parameters by up to 90% compared with ControlNet, achieving faster convergence and outstanding efficiency. This method can be directly combined with other LoRA techniques to alter style and ensure more stable generation. Please refer to the examples for more details.
10 |
11 | We provide an online demo of [ControlNeXt-SDXL](./ControlNeXt-SDXL/). Due to the high resource requirements of SVD, we are unable to offer it online.
12 |
13 | > This project is still undergoing iterative development. The code and model may be updated at any time. More information will be provided later.
14 |
15 | # Experiences
16 | We share more training experiences [here](./experiences.md) and in this [issue](https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333).
17 | We spent a lot of time discovering them, and we now share them with all of you. We hope they help!
18 |
19 | # Model Zoo
20 |
21 | - **ControlNeXt-SDXL** [ [Link](ControlNeXt-SDXL) ] : Controllable image generation. Our model is built upon [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). Fewer trainable parameters, faster convergence, improved efficiency, and can be integrated with LoRA.
22 |
23 | - **ControlNeXt-SDXL-Training** [ [Link](ControlNeXt-SDXL-Training) ] : The training scripts for our `ControlNeXt-SDXL` [ [Link](ControlNeXt-SDXL) ].
24 |
25 | - **ControlNeXt-SVD-v2** [ [Link](ControlNeXt-SVD-v2) ] : Generate the video controlled by the sequence of human poses. In the v2 version, we implement several improvements: a higher-quality collected training dataset, larger training and inference batch frames, higher generation resolution, enhanced human-related video generation through continual training, and pose alignment for inference to improve overall performance.
26 |
27 | - **ControlNeXt-SVD-v2-Training** [ [Link](ControlNeXt-SVD-v2-Training) ] : The training scripts for our `ControlNeXt-SVD-v2` [ [Link](ControlNeXt-SVD-v2) ].
28 |
29 | - **ControlNeXt-SVD** [ [Link](ControlNeXt-SVD) ] : Generate the video controlled by the sequence of human poses. This can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). However, our model is built upon [Stable Video Diffusion](https://stability.ai/stable-video), employing a more concise architecture.
30 |
31 | - **ControlNeXt-SD1.5** [ [Link](ControlNeXt-SD1.5) ] : Controllable image generation. Our model is built upon [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Fewer trainable parameters, faster convergence, improved efficiency, and can be integrated with LoRA.
32 |
33 | - **ControlNeXt-SD1.5-Training** [ [Link](ControlNeXt-SD1.5-Training) ] : The training scripts for our `ControlNeXt-SD1.5` [ [Link](ControlNeXt-SD1.5) ].
34 |
35 | - **ControlNeXt-SD3** [ [Link](ControlNeXt-SD3) ] : We regret to inform you that ControlNeXt-SD3 was trained with protected, private data and code, and therefore cannot be released.
36 |
37 |
38 |
39 | # 🎥 Examples
40 | ### For more examples, please refer to our [Project page](https://pbihao.github.io/projects/controlnext/index.html).
41 |
42 | ### [ControlNeXt-SDXL](ControlNeXt-SDXL)
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | ### [ControlNeXt-SVD-v2](ControlNeXt-SVD-v2)
52 | If you can't load the videos, you can also directly download them from [here](examples/demos) and [here](examples/video).
53 | Or you can view them from our [Project Page](https://pbihao.github.io/projects/controlnext/index.html) or [BiliBili](https://www.bilibili.com/video/BV1wJYbebEE7/?buvid=YC4E03C93B119ADD4080B0958DE73F9DDCAC&from_spmid=dt.dt.video.0&is_story_h5=false&mid=y82Gz7uArS6jTQ6zuqJj3w%3D%3D&p=1&plat_id=114&share_from=ugc&share_medium=iphone&share_plat=ios&share_session_id=4E5549FC-0710-4030-BD2C-CDED80B46D08&share_source=WEIXIN&share_source=weixin&share_tag=s_i&timestamp=1723123770&unique_k=XLZLhCq&up_id=176095810&vd_source=3791450598e16da25ecc2477fc7983db).
54 |
55 |
56 |
57 |
58 |
59 | |
60 |
61 |
62 | |
63 |
64 |
65 |
66 |
67 | |
68 |
69 |
70 | |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 | ### [ControlNeXt-SVD](ControlNeXt-SVD)
81 | If you can't load the videos, you can also directly download them from [here](ControlNeXt-SVD/outputs).
82 |
83 |
84 |
85 |
86 |
87 |
90 |
91 |
92 |
93 |
94 |
95 | |
96 |
97 |
98 | |
99 |
100 |
101 |
102 |
103 |
104 | ### [ControlNeXt-SD1.5](ControlNeXt-SD1.5)
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 | ### If you find this work useful, please consider citing:
118 | ```
119 | @article{peng2024controlnext,
120 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation},
121 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya},
122 | journal={arXiv preprint arXiv:2408.06070},
123 | year={2024}
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/compress_image.py:
--------------------------------------------------------------------------------
1 | # from PIL import Image
2 | # import cv2
3 | # import numpy as np
4 |
5 | # img_path = "/home/llm/bhpeng/github/ControlAny/ControlAny-SDXL/examples/vidit_depth/condition_0.png"
6 | # save_path = "/home/llm/bhpeng/github/ControlAny/ControlAny-SDXL/examples/vidit_depth/condition_02.png"
7 |
8 | # length = 1
9 | # select_id = []
10 |
11 | # image = cv2.imread(img_path)
12 | # height, width, _ = image.shape
13 | # part_width = width // length
14 |
15 | # splited_imgs = []
16 | # for i in range(length):
17 | # left = i * part_width
18 | # right = (i + 1) * part_width if i < length - 1 else width  # ensure the last chunk reaches the right edge of the image
19 |
20 | # split_img = image[:, left:right]
21 | # splited_imgs.append(split_img)
22 |
23 | # merge_imgs = []
24 | # merge_imgs.append(splited_imgs[0])
25 | # for i in select_id:
26 | # merge_imgs.append(splited_imgs[i])
27 | # merge_imgs = np.concatenate(merge_imgs, axis=1)
28 | # print(merge_imgs.shape)
29 | # resized_img = cv2.resize(merge_imgs, (merge_imgs.shape[1]//2, merge_imgs.shape[0]//2), interpolation=cv2.INTER_AREA)
30 | # print(resized_img.shape)
31 | # cv2.imwrite(save_path, resized_img, [cv2.IMWRITE_JPEG_QUALITY, 85])
32 |
33 | # img = cv2.resize(img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_AREA)
34 |
35 |
36 | # from moviepy.editor import VideoFileClip
37 | # import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
38 |
39 |
40 | # video_path = 'ControlNeXt-SVD/outputs/chair/chair.mp4'
41 | # clip = VideoFileClip(video_path)
42 |
43 | # gif_path = 'ControlNeXt-SVD/outputs/chair/chair.gif'
44 | # clip.write_gif(gif_path, fps=14, program='ffmpeg', opt="nq", fuzz=1, )
45 |
46 |
47 | from PIL import Image
48 | import os
49 |
50 | def compress_image(input_path, output_path, quality=85):
51 | """
52 | Compress an image while preserving as much quality as possible.
53 |
54 | :param input_path: path to the original image
55 | :param output_path: path to save the compressed image
56 | :param quality: compression quality, in the range 0 to 100, where 100 is the highest quality
57 | """
58 | # Open the image
59 | with Image.open(input_path) as img:
60 | # Ensure the image is in RGB mode
61 | if img.mode in ("RGBA", "P"):
62 | img = img.convert("RGB")
63 |
64 | # Save the compressed image
65 | img.save(output_path, "JPEG", quality=quality, optimize=True)
66 |
67 |
68 | img_paths = [
69 | 'ControlNeXt-SDXL/examples/demo/demo1.png',
70 | 'ControlNeXt-SDXL/examples/demo/demo3.png',
71 | 'ControlNeXt-SDXL/examples/demo/demo5.png'
72 | ]
73 | quality = 50
74 |
75 | for src_path in img_paths:
76 | dst_path = src_path
77 | src_path = os.path.join(os.path.split(src_path)[0], 'src_'+os.path.split(src_path)[1])
78 | os.rename(dst_path, src_path)
79 | dst_path = '.'.join(dst_path.split('.')[:-1])+'.jpg'
80 | compress_image(src_path, dst_path, quality=quality)
--------------------------------------------------------------------------------
/experiences.md:
--------------------------------------------------------------------------------
1 | # 🌀 ControlNeXt-Experiences
2 |
3 | As we all know, developing a high-quality model is not just an academic challenge; it also requires extensive engineering experience. Therefore, we are sharing the insights we gained during this project. These insights are the result of significant time and effort. If you find them helpful, please consider giving us a star or citing our work.
4 |
5 | We hope they help you.
6 |
7 | ### 1. Human-related generation
8 |
9 | As I’ve mentioned, we only select a small subset of parameters, and this works well for the SD1.5 and SDXL backbones: by training fewer than 100 million parameters, we still achieve excellent performance. But this is not suitable for SD3 and SVD training. This is because, after SDXL, Stability AI faced significant legal risks due to the generation of highly realistic human images. After that, they stopped refining their later models, such as SVD and SD3, on human-related data to avoid potential risks.
10 |
11 | To achieve optimal performance, it is necessary to first continue training SVD and SD3 on human-related data to develop a robust backbone before fine-tuning. Of course, you can also combine continual pretraining and fine-tuning (open all the parameters for training; there will not be a significant difference). That is why we directly provide the full SVD parameters.
12 |
13 | Although this may not be directly related to academia, it is crucial for achieving good performance.
14 |
15 | ### 2. Data
16 |
17 | Due to privacy policies, we are unable to share the data. However, data quality is crucial, as many videos on the internet are highly compressed. It’s important to focus on collecting high-quality data without compression.
18 |
19 | ### 3. Hands
20 |
21 | Generating hands is a challenging problem in both video and image generation. To address this, we focus on the following strategies:
22 |
23 | a. Use clear and high-quality data, which is crucial for accurate generation.
24 |
25 | b. Since the hands occupy a relatively small area, we apply a larger scale for the loss function specifically for this region to improve the generation quality.
26 |
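As a minimal sketch of the region-weighted loss in (b): assuming you already have a binary hand mask per training sample (for example rasterized from DWPose hand keypoints and resized to the latent resolution), the extra weight can be folded into the ordinary diffusion MSE loss. The weight value and helper name below are illustrative, not the exact training code.

```python
import torch.nn.functional as F

def hand_weighted_mse(model_pred, target, hand_mask, hand_weight=3.0):
    """Diffusion MSE loss with a larger weight inside the hand region.

    model_pred, target: (B, C, H, W) predicted vs. target noise.
    hand_mask:          (B, 1, H, W) binary mask at the same resolution, 1 on hands.
    hand_weight:        extra scale for hand pixels (illustrative value).
    """
    per_pixel = F.mse_loss(model_pred, target, reduction="none")
    weight = 1.0 + (hand_weight - 1.0) * hand_mask  # 1 elsewhere, hand_weight on hands
    return (per_pixel * weight).mean()
```
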
27 | ### 4. Pose alignment
28 |
29 | Thanks to [MimicMotion](https://github.com/Tencent/MimicMotion). SVD performs poorly with large motions, so it is important to avoid large movements and shifts. Please note that in [preprocess](https://github.com/dvlab-research/ControlNeXt/blob/main/ControlNeXt-SVD-v2/dwpose/preprocess.py), there is an alignment between the reference image and the pose. This is crucial.
30 |
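Conceptually, the alignment rescales and shifts the driving-pose keypoints so that the first driving frame roughly matches the body position and scale of the reference pose, and the same transform is then applied to every frame. The sketch below is a simplified moment-matching version of that idea, not the repository's exact alignment code.

```python
import numpy as np

def align_driving_poses(ref_kpts, video_kpts):
    """Rescale/shift driving-pose keypoints toward the reference pose.

    ref_kpts:   (K, 2) keypoints detected on the reference image.
    video_kpts: (T, K, 2) keypoints detected on the driving video.
    """
    src, dst = video_kpts[0], ref_kpts
    # Per-axis scale and offset that match the mean and spread of the first frame.
    scale = (dst.std(axis=0) + 1e-6) / (src.std(axis=0) + 1e-6)
    offset = dst.mean(axis=0) - src.mean(axis=0) * scale
    return video_kpts * scale + offset  # same transform applied to all frames
```
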
31 |
32 | ### 5. Control level
33 |
34 | You can see that we adopt a magic number when adding the conditions.
35 |
36 | Such as in `ControlNeXt-SVD-v2/models/unet_spatio_temporal_condition_controlnext.py`:
37 | ```python
38 | sample = sample + conditional_controls * scale * 0.2
39 | ```
40 |
41 | Notice that we multiply by `0.2`. This hyperparameter is used to adjust the control level: increasing this value will strengthen the control.
42 |
43 | However, if this value is set too high, the control may become overly strong and may not be apparent in the final generated images.
44 |
45 | So you can adjust it to get a good result. In our experience, for dense controls such as super-resolution or depth, we need to set it to `1`.
46 |
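For reference, the ControlNeXt module returns a dict with both the control feature and a `scale` (see `models/controlnext_vid_svd.py`), and the UNet adds it into its hidden states roughly as sketched below, with the `0.2` factor exposed as a named constant. This is a simplified illustration of the quoted line, not the full injection code.

```python
import torch.nn.functional as F

CONTROL_LEVEL = 0.2  # raise toward ~1.0 for dense conditions such as depth or super-resolution

def inject_control(sample, controls, control_level=CONTROL_LEVEL):
    """Add the ControlNeXt feature into a UNet hidden state with matching channels."""
    cond = controls["output"]
    if cond.shape[-2:] != sample.shape[-2:]:
        # Match the spatial size of the UNet feature map if needed.
        cond = F.interpolate(cond, size=sample.shape[-2:], mode="nearest")
    return sample + cond * controls["scale"] * control_level
```
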
47 |
48 | ### 6. Training parameters
49 |
50 | One of the most important findings is that directly training the base model yields better performance compared to methods like LoRA, Adapter, and others. Even when we train the base model, we only select a small subset of the pre-trained parameters. You can also adaptively adjust the number of selected parameters. For example, with high-quality data, having more trainable parameters can improve performance. However, this is a trade-off, and regardless of the approach, directly training the base model often yields the best results.
51 |
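What "selecting a small subset of the pre-trained parameters" can look like in practice is sketched below for a diffusers-style UNet: freeze everything, then re-enable gradients only for parameters whose names match a few patterns. The patterns and learning rate are illustrative placeholders, not the selection rule used for the released checkpoints.

```python
import torch

def select_trainable_params(unet, patterns=("attn2", "to_out")):
    """Freeze the UNet, then unfreeze parameters whose names contain any pattern."""
    unet.requires_grad_(False)
    trainable = []
    for name, param in unet.named_parameters():
        if any(p in name for p in patterns):
            param.requires_grad_(True)
            trainable.append(param)
    print(f"trainable parameters: {sum(p.numel() for p in trainable) / 1e6:.1f}M")
    return trainable

# optimizer = torch.optim.AdamW(select_trainable_params(unet), lr=1e-5)
```
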
52 |
53 | ### If you find this work useful, please consider citing:
54 | ```
55 | @article{peng2024controlnext,
56 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation},
57 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya},
58 | journal={arXiv preprint arXiv:2408.06070},
59 | year={2024}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | moviepy
3 | opencv-python
4 | pillow
5 | numpy
6 | transformers
7 | diffusers
8 | safetensors
9 | peft
10 | decord
11 | einops
--------------------------------------------------------------------------------