├── .gitignore ├── ControlNeXt-SD1.5-Training ├── README.md ├── examples │ ├── conditioning_image_1.png │ └── conditioning_image_2.png ├── models │ ├── controlnext.py │ └── unet.py ├── pipeline │ └── pipeline_controlnext.py ├── run_controlnext.py ├── scripts.sh └── train_controlnext.py ├── ControlNeXt-SD1.5 ├── README.md ├── examples │ ├── deepfashion_caption │ │ ├── condition_0.png │ │ ├── condition_1.png │ │ ├── eval_img │ │ │ ├── chinese_style.jpg │ │ │ ├── image_0.jpg │ │ │ ├── warrior_bad.jpg │ │ │ └── warrior_good.jpg │ │ └── script.sh │ ├── deepfashion_multiview │ │ ├── condition_0.jpg │ │ ├── condition_1.jpg │ │ ├── condition_2.jpg │ │ ├── eval_img │ │ │ ├── Anythingv3.jpg │ │ │ ├── Anythingv3_fischl.jpg │ │ │ ├── Anythingv3_lisa.jpg │ │ │ └── DreamShaper.jpg │ │ └── script.sh │ ├── deepfashoin_mask │ │ ├── condition_0.png │ │ ├── eval_img │ │ │ └── image_0.jpg │ │ └── script.sh │ └── vidit_depth │ │ ├── condition_0.png │ │ ├── eval_img │ │ └── image_0.jpg │ │ └── script.sh ├── models │ ├── controlnet.py │ ├── pipeline_controlnext.py │ └── unet.py └── run_controlnext.py ├── ControlNeXt-SD3 └── README.md ├── ControlNeXt-SDXL-Training ├── README.md ├── examples │ └── vidit_depth │ │ ├── condition_0.png │ │ └── train.sh ├── models │ ├── controlnet.py │ └── unet.py ├── pipeline │ └── pipeline_controlnext.py ├── requirements.txt ├── train_controlnext.py └── utils │ ├── preprocess.py │ ├── tools.py │ └── utils.py ├── ControlNeXt-SDXL ├── README.md ├── examples │ ├── anime_canny │ │ ├── condition_0.jpg │ │ ├── eval_img │ │ │ ├── AAM.jpg │ │ │ └── NetaXLV2.jpg │ │ ├── image_0.jpg │ │ ├── run.sh │ │ └── run_with_pp.sh │ ├── demo │ │ ├── demo1.jpg │ │ ├── demo2.jpg │ │ ├── demo3.jpg │ │ ├── demo4.jpg │ │ └── demo5.jpg │ └── vidit_depth │ │ ├── condition_0.png │ │ ├── eval_img │ │ ├── StableDiffusionXL.jpg │ │ └── StableDiffusionXL_GlassSculpturesLora.jpg │ │ └── run.sh ├── models │ ├── controlnet.py │ └── unet.py ├── pipeline │ └── pipeline_controlnext.py ├── requirements.txt ├── run_controlnext.py └── utils │ ├── preprocess.py │ ├── tools.py │ └── utils.py ├── ControlNeXt-SVD-v2-Training ├── README.md ├── deepspeed.yaml ├── meta_info_example │ ├── meta_info.json │ └── meta_info │ │ └── 1.json ├── models │ ├── controlnext_vid_svd.py │ └── unet_spatio_temporal_condition_controlnext.py ├── pipeline │ └── pipeline_stable_video_diffusion_controlnext.py ├── requirements.txt ├── script.sh ├── train_svd.py └── utils │ ├── dataset.py │ ├── extract_learned_paras.py │ ├── extract_vid2img.py │ ├── img_dataset.py │ ├── pkl_dataset.py │ ├── scheduling_euler_discrete_karras_fix.py │ ├── ubc_dataset.py │ ├── unwrap_deepspeed.py │ ├── util.py │ └── vid_dataset.py ├── ControlNeXt-SVD-v2 ├── README.md ├── dwpose │ ├── __init__.py │ ├── dwpose_detector.py │ ├── onnxdet.py │ ├── onnxpose.py │ ├── preprocess.py │ ├── util.py │ └── wholebody.py ├── examples │ ├── demos │ │ ├── 01-1.mp4 │ │ ├── 02-1.mp4 │ │ ├── 03-1.mp4 │ │ └── 04-1.mp4 │ ├── facefusion │ │ └── facefusion.jpg │ ├── ref_imgs │ │ ├── 01.jpeg │ │ ├── 02.jpeg │ │ ├── 03.jpeg │ │ └── 04.jpeg │ └── video │ │ ├── 01.mp4 │ │ └── 02.mp4 ├── models │ ├── controlnext_vid_svd.py │ └── unet_spatio_temporal_condition_controlnext.py ├── pipeline │ └── pipeline_stable_video_diffusion_controlnext.py ├── run_controlnext.py ├── script.sh └── utils │ ├── pre_process.py │ └── scheduling_euler_discrete_karras_fix.py ├── ControlNeXt-SVD ├── README.md ├── examples │ ├── facefusion │ │ └── facefusion.jpg │ ├── pose │ │ └── pose.mp4 │ └── ref_imgs │ │ ├── spiderman.jpg │ │ 
└── tiktok.png ├── models │ ├── controlnext_vid_svd.py │ └── unet_spatio_temporal_condition_controlnext.py ├── outputs │ ├── chair │ │ └── chair.mp4 │ ├── collected │ │ ├── demo.jpg │ │ ├── demo.mp4 │ │ └── out2.mp4 │ ├── spiderman │ │ └── spiderman.mp4 │ ├── star │ │ └── star.mp4 │ └── tiktok │ │ └── tiktok.mp4 ├── pipeline │ └── pipeline_stable_video_diffusion_controlnext.py ├── run_controlnext.py ├── script.sh └── utils │ └── scheduling_euler_discrete_karras_fix.py ├── LICENSE ├── README.md ├── compress_image.py ├── experiences.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ -------------------------------------------------------------------------------- /ControlNeXt-SD1.5-Training/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SD1.5 2 | 3 | 4 | This is the training script for our ControlNeXt model, based on Stable Diffusion 1.5. 5 | 6 | Our training and inference code has undergone some updates compared to the original version. Please refer to this version as the standard. 7 | 8 | We provide an example using an open dataset, where our method achieves convergence in just a thousand training steps. 9 | 10 | ## Train 11 | 12 | 13 | ``` 14 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \ 15 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 16 | --output_dir="checkpoints" \ 17 | --dataset_name=fusing/fill50k \ 18 | --resolution=512 \ 19 | --learning_rate=1e-5 \ 20 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 21 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 22 | --checkpoints_total_limit 3 \ 23 | --checkpointing_steps 400 \ 24 | --validation_steps 400 \ 25 | --num_train_epochs 4 \ 26 | --train_batch_size=6 \ 27 | --controlnext_scale 0.35 \ 28 | --save_load_weights_increaments 29 | ``` 30 | 31 | > --controlnext_scale: Set between [0, 1]; controls the strength of ControlNeXt. A larger value indicates stronger control. For tasks requiring dense conditional controls, such as depth, setting it larger (such as 1.) will provide better control. Increasing this number will lead to faster convergence and stronger control, but it can sometimes overly influence the final generation. 32 | 33 | 34 | > --save_load_weights_increments: Choose whether to save the trainable parameters directly or just the weight increments, i.e., $W_{finetune} - W_{pretrained}$. This is useful when adapting to various backbones. 
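
For intuition, here is a minimal, illustrative sketch of what saving and re-applying weight increments means. It is not the repository's training or inference code; the helper names and toy state dicts are hypothetical.

```python
# Illustrative sketch only: weight increments delta_W = W_finetune - W_pretrained.
# Helper names and the toy state dicts below are hypothetical, not from this repo.
import torch

def extract_weight_increments(finetuned_sd, pretrained_sd, trainable_keys):
    # Store only the increments of the trainable subset, which keeps the checkpoint small.
    return {k: finetuned_sd[k] - pretrained_sd[k] for k in trainable_keys}

def apply_weight_increments(backbone_sd, increments):
    # Add the learned increments onto another, architecturally compatible backbone.
    merged = dict(backbone_sd)
    for k, delta in increments.items():
        merged[k] = merged[k] + delta.to(merged[k].dtype)
    return merged

# Toy example with a single parameter tensor:
pretrained = {"attn.to_q.weight": torch.zeros(4, 4)}
finetuned = {"attn.to_q.weight": 0.1 * torch.ones(4, 4)}
increments = extract_weight_increments(finetuned, pretrained, ["attn.to_q.weight"])
other_backbone = {"attn.to_q.weight": torch.randn(4, 4)}
merged = apply_weight_increments(other_backbone, increments)
```

Because only the difference is stored, the same increments can be added onto a different compatible backbone without saving that backbone's full UNet again.
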
35 | 36 | ## Inference 37 | 38 | 39 | ``` 40 | python run_controlnext.py \ 41 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 42 | --output_dir="test" \ 43 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 44 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 45 | --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \ 46 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \ 47 | --controlnext_scale 0.35 48 | ``` 49 | 50 | ``` 51 | python run_controlnext.py \ 52 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 53 | --output_dir="test" \ 54 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 55 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 56 | --controlnet_model_name_or_path checkpoints/checkpoint-800/controlnext.bin \ 57 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet_weight_increasements.bin \ 58 | --controlnext_scale 0.35 \ 59 | --save_load_weights_increaments 60 | ``` -------------------------------------------------------------------------------- /ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5-Training/models/controlnext.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D 10 | 11 | 12 | class ControlNeXtModel(ModelMixin, ConfigMixin): 13 | _supports_gradient_checkpointing = True 14 | 15 | @register_to_config 16 | def __init__( 17 | self, 18 | time_embed_dim = 256, 19 | in_channels = [128, 128], 20 | out_channels = [128, 256], 21 | groups = [4, 8], 22 | controlnext_scale=1. 
23 | ): 24 | super().__init__() 25 | 26 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0) 27 | self.time_embedding = TimestepEmbedding(128, time_embed_dim) 28 | self.embedding = nn.Sequential( 29 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 30 | nn.GroupNorm(2, 64), 31 | nn.ReLU(), 32 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 33 | nn.GroupNorm(2, 64), 34 | nn.ReLU(), 35 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 36 | nn.GroupNorm(2, 128), 37 | nn.ReLU(), 38 | ) 39 | 40 | self.down_res = nn.ModuleList() 41 | self.down_sample = nn.ModuleList() 42 | for i in range(len(in_channels)): 43 | self.down_res.append( 44 | ResnetBlock2D( 45 | in_channels=in_channels[i], 46 | out_channels=out_channels[i], 47 | temb_channels=time_embed_dim, 48 | groups=groups[i] 49 | ), 50 | ) 51 | self.down_sample.append( 52 | Downsample2D( 53 | out_channels[i], 54 | use_conv=True, 55 | out_channels=out_channels[i], 56 | padding=1, 57 | name="op", 58 | ) 59 | ) 60 | 61 | self.mid_convs = nn.ModuleList() 62 | self.mid_convs.append(nn.Sequential( 63 | nn.Conv2d( 64 | in_channels=out_channels[-1], 65 | out_channels=out_channels[-1], 66 | kernel_size=3, 67 | stride=1, 68 | padding=1 69 | ), 70 | nn.ReLU(), 71 | nn.GroupNorm(8, out_channels[-1]), 72 | nn.Conv2d( 73 | in_channels=out_channels[-1], 74 | out_channels=out_channels[-1], 75 | kernel_size=3, 76 | stride=1, 77 | padding=1 78 | ), 79 | nn.GroupNorm(8, out_channels[-1]), 80 | )) 81 | self.mid_convs.append( 82 | nn.Conv2d( 83 | in_channels=out_channels[-1], 84 | out_channels=320, 85 | kernel_size=1, 86 | stride=1, 87 | )) 88 | 89 | self.scale = controlnext_scale 90 | 91 | def forward( 92 | self, 93 | sample: torch.FloatTensor, 94 | timestep: Union[torch.Tensor, float, int], 95 | ): 96 | 97 | timesteps = timestep 98 | if not torch.is_tensor(timesteps): 99 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 100 | # This would be a good case for the `match` statement (Python 3.10+) 101 | is_mps = sample.device.type == "mps" 102 | if isinstance(timestep, float): 103 | dtype = torch.float32 if is_mps else torch.float64 104 | else: 105 | dtype = torch.int32 if is_mps else torch.int64 106 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 107 | elif len(timesteps.shape) == 0: 108 | timesteps = timesteps[None].to(sample.device) 109 | 110 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 111 | batch_size = sample.shape[0] 112 | timesteps = timesteps.expand(batch_size) 113 | 114 | t_emb = self.time_proj(timesteps) 115 | 116 | # `Timesteps` does not contain any weights and will always return f32 tensors 117 | # but time_embedding might actually be running in fp16. so we need to cast here. 118 | # there might be better ways to encapsulate this. 
119 | t_emb = t_emb.to(dtype=sample.dtype) 120 | 121 | emb = self.time_embedding(t_emb) 122 | 123 | sample = self.embedding(sample) 124 | 125 | for res, downsample in zip(self.down_res, self.down_sample): 126 | sample = res(sample, emb) 127 | sample = downsample(sample, emb) 128 | 129 | sample = self.mid_convs[0](sample) + sample 130 | sample = self.mid_convs[1](sample) 131 | 132 | return { 133 | 'output': sample, 134 | 'scale': self.scale, 135 | } 136 | 137 | -------------------------------------------------------------------------------- /ControlNeXt-SD1.5-Training/scripts.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \ 2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 3 | --output_dir="checkpoints" \ 4 | --dataset_name=fusing/fill50k \ 5 | --resolution=512 \ 6 | --learning_rate=1e-5 \ 7 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 8 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 9 | --checkpoints_total_limit 3 \ 10 | --checkpointing_steps 400 \ 11 | --validation_steps 400 \ 12 | --num_train_epochs 4 \ 13 | --train_batch_size=6 \ 14 | --controlnext_scale 0.35 \ 15 | --save_load_weights_increaments 16 | 17 | 18 | 19 | 20 | CUDA_VISIBLE_DEVICES=4 python run_controlnext.py \ 21 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 22 | --output_dir="test" \ 23 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 24 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 25 | --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \ 26 | --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \ 27 | --controlnext_scale 0.35 28 | 29 | 30 | 31 | CUDA_VISIBLE_DEVICES=5 python run_controlnext.py \ 32 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 33 | --output_dir="test" \ 34 | --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ 35 | --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ 36 | --controlnet_model_name_or_path checkpoints/checkpoint-400/controlnext.bin \ 37 | --unet_model_name_or_path checkpoints/checkpoint-400/unet_weight_increasements.bin \ 38 | --controlnext_scale 0.35 \ 39 | --save_load_weights_increaments -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SD1.5 2 | 3 | `Please refer to SDXL and SVD for our newly updated version !` 4 | 5 | This is our implementation of ControlNeXt based on [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). 6 | 7 | > Please refer to [Examples](#examples) for further intuitive details.\ 8 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\ 9 | > Please refer to [Model Zoo](#model-zoo) for more other our trained models. 10 | 11 | Our method demonstrates the advantages listed below: 12 | 13 | - **Few trainable parameters**: only requiring **5~30M** trainable parameters (occupying 20~80 MB of memory). 14 | - **Fast training speed**: no sudden convergence. 
15 | - **Efficient**: no need for an additional branch; only a lightweight module is required. 16 | - **Compatibility**: can serve as a **plug-and-play** lightweight module and can be combined with other LoRA weights. 17 | 18 | # Examples 19 | 20 | The demo examples are generated using ControlNeXt trained on the deepfashion_multiview dataset, with [DreamShaper](https://huggingface.co/Lykon/DreamShaper) as the base model. Our method demonstrates excellent compatibility and can be applied to most other models based on the SD1.5 architecture and LoRA. You can also retrain your own model for better performance. 21 | 22 | ## BaseModel 23 | 24 | Our model can be applied to various base models as a plug-and-play module without the need for further training. 25 | 26 | > 📌 Of course, you can retrain your own model, especially for complex tasks and to achieve better performance. 27 | 28 | - [DreamShaper](https://huggingface.co/Lykon/DreamShaper) 29 | 30 |

31 | DreamShaper 32 |

33 | 34 | - [Anything-v3.0](https://huggingface.co/admruul/anything-v3.0) 35 | 36 |

37 | Anythingv3 38 |

39 | 40 | ## LoRA 41 | 42 | Our model can also be directly combined with other publicly available LoRA weights. 43 | 44 | - [Lisa](https://civitai.com/articles/4584) 45 | 46 |

47 | Anythingv3 48 |

49 | 50 | - [Fischl](https://civitai.com/articles/4584) 51 | 52 |

53 | Anythingv3 54 |

55 | 56 | - [Chinese Style](https://civitai.com/models/12597/moxin) 57 | 58 |

59 | Anythingv3 60 |

61 | 62 | ## Stable Generation 63 | 64 | Sometimes it is difficult to generate good results, and you have to repeatedly adjust your prompt to achieve a satisfactory outcome. This process is challenging because prompts are very abstract. 65 | 66 | Our method can serve as a plug-and-play module for stable generation. 67 | 68 | - Without ControlNeXt (using the original [SD](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model) 69 |

70 | Anythingv3 71 |

72 | 73 | - With ControlNeXt (using the original [SD](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the base model) 74 |

75 | Anythingv3 76 |

77 | 78 | # Inference 79 | 80 | 1. Clone our repository 81 | 2. `cd ControlNeXt-SD1.5` 82 | 3. Download the pretrained weights into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SD1.5). (We recommend `deepfashion_multiview` and `deepfashion_caption`.) 83 | 4. (Optional) Download LoRA weights, such as [Genshin](https://civitai.com/models/362091/sd15all-characters-genshin-impact-124-characters-124), and put them under `lora/` 84 | 5. Run the script 85 | 86 | ```bash 87 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 88 | --pretrained_model_name_or_path="admruul/anything-v3.0" \ 89 | --output_dir="examples/deepfashion_multiview" \ 90 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 91 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \ 92 | --negative_prompt "PBH" "PBH"\ 93 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 94 | (Optional)--lora_path lora/yuanshen/genshin_124.safetensors \ 95 | (Optional, less generality, stricter control)--unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 96 | ``` 97 | 98 | > --pretrained_model_name_or_path : the pretrained base model; we have tried [DreamShaper](https://huggingface.co/Lykon/DreamShaper), [Anything-v3.0](https://huggingface.co/admruul/anything-v3.0), and the [original SD](https://huggingface.co/runwayml/stable-diffusion-v1-5) \ 99 | > --controlnet_model_name_or_path : the path of the controlnet module (a lightweight module) \ 100 | > --lora_path : the path of an additionally downloaded LoRA weight \ 101 | > --unet_model_name_or_path : the path of the finetuned subset of UNet parameters 102 | 103 | > 📌 Pose-based generation is a relatively simple task, and in most cases it is enough to load only the control module via `--controlnet_model_name_or_path`. However, for harder tasks you may also need to load a selected subset of the original UNet parameters (which can be seen as another kind of LoRA). \ 104 | > More parameters mean weaker generality, so make your own trade-off, or train your own model directly on your own data; training is also fast. 105 | 106 | # Model Zoo 107 | 108 | We also provide some additional examples, but these are just for demonstration purposes. The training data is relatively small and of low quality. 109 | 110 | - vidit_depth 111 |

112 | Anythingv3 113 |

114 | 115 | - mask 116 |

117 | Anythingv3 118 |

119 | 120 | # TODO 121 | -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/condition_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_0.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/condition_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/condition_1.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/chinese_style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/chinese_style.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/image_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/image_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_bad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_bad.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_good.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_caption/eval_img/warrior_good.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_caption/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 3 | --output_dir="examples/deepfashion_caption" \ 4 | --validation_image "examples/deepfashion_caption/condition_1.png" "examples/deepfashion_caption/condition_0.png" \ 5 | --validation_prompt "a woman wearing a black shirt and black leather skirt" "levi's women's white graphic t - shirt" \ 6 | --controlnet_model_name_or_path pretrained/deepfashion_caption/controlnet.safetensors \ 7 | --unet_model_name_or_path pretrained/deepfashion_caption/unet.safetensors 8 | -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_1.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/condition_2.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_fischl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_fischl.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_lisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/Anythingv3_lisa.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/DreamShaper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashion_multiview/eval_img/DreamShaper.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashion_multiview/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 2 | --pretrained_model_name_or_path="admruul/anything-v3.0" \ 3 | --output_dir="examples/deepfashion_multiview" \ 4 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 5 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), 
official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \ 6 | --negative_prompt "PBH" "PBH"\ 7 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 8 | --lora_path lora/yuanshen/genshin_124.safetensors \ 9 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 10 | 11 | 12 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 13 | --pretrained_model_name_or_path="admruul/anything-v3.0" \ 14 | --output_dir="examples/deepfashion_multiview" \ 15 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 16 | --validation_prompt "1boy, braid, single_earring, short_sleeves, white_scarf, black_gloves, alternate_costume, standing, black_shirt" "1boy, braid, single_earring, short_sleeves, white_scarf, black_gloves, alternate_costume, standing, black_shirt" \ 17 | --negative_prompt "PBH" "PBH"\ 18 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 19 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 20 | 21 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 22 | --pretrained_model_name_or_path="admruul/anything-v3.0" \ 23 | --output_dir="examples/deepfashion_multiview" \ 24 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 25 | --validation_prompt "lisa_\(a_sobriquet_under_shade\)_\(genshin_impact\), lisa_\(genshin_impact\), 1girl, green_headwear, official_alternate_costume, cleavage, twin_braids, hair_flower, vision_\(genshin_impact\), large_breasts, thighlet, puffy_long_sleeves, purple_rose, beret" "lisa_\(a_sobriquet_under_shade\)_\(genshin_impact\), lisa_\(genshin_impact\), 1girl, green_headwear, official_alternate_costume, cleavage, twin_braids, hair_flower, vision_\(genshin_impact\), large_breasts, thighlet, puffy_long_sleeves, purple_rose, beret" \ 26 | --negative_prompt "PBH" "PBH"\ 27 | --lora_path lora/yuanshen/genshin_124.safetensors \ 28 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 29 | 30 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 31 | --pretrained_model_name_or_path="admruul/anything-v3.0" \ 32 | --output_dir="examples/deepfashion_multiview" \ 33 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 34 | --validation_prompt "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" "fischl_\(genshin_impact\), fischl_\(ein_immernachtstraum\)_\(genshin_impact\), official_alternate_costume, 1girl, eyepatch, detached_sleeves, tiara, hair_over_one_eye, bare_shoulders, purple_dress, white_thighhighs, long_sleeves, hair_ribbon, purple_ribbon, white_pantyhose" \ 35 | --negative_prompt "PBH" "PBH"\ 36 | --lora_path lora/yuanshen/genshin_124.safetensors \ 37 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 38 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 39 | 40 | 41 | # base generation 42 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 43 | 
--pretrained_model_name_or_path="Lykon/DreamShaper" \ 44 | --output_dir="examples/deepfashion_multiview" \ 45 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg"\ 46 | --validation_prompt "a woman in white shorts and a tank top" "a woman wearing a black shirt and black leather skirt" \ 47 | --negative_prompt "PBH" "PBH" \ 48 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 49 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 50 | 51 | 52 | # Combine with LoRA 53 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 54 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 55 | --output_dir="examples/deepfashion_multiview" \ 56 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg"\ 57 | --negative_prompt "PBH" "PBH" \ 58 | --validation_prompt "c1bo, a woman, Armor, weapon, beautiful" "c1bo, a man, fight" \ 59 | --lora_path lora/c1bo/cyborg_v_2_SD15.safetensors \ 60 | --controlnet_model_name_or_path pretrained/deepfashion_multiview/controlnet.safetensors \ 61 | --unet_model_name_or_path pretrained/deepfashion_multiview/unet.safetensors 62 | 63 | # Combine with LoRA, without our control 64 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 65 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 66 | --output_dir="examples/deepfashion_multiview" \ 67 | --validation_image "examples/deepfashion_multiview/condition_0.jpg" "examples/deepfashion_multiview/condition_1.jpg" \ 68 | --negative_prompt "PBH" "PBH" \ 69 | --validation_prompt "c1bo, a woman, Armor, weapon, beautiful" "c1bo, a man, fight" \ 70 | --lora_path lora/c1bo/cyborg_v_2_SD15.safetensors 71 | 72 | 73 | -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashoin_mask/condition_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashoin_mask/condition_0.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashoin_mask/eval_img/image_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/deepfashoin_mask/eval_img/image_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/deepfashoin_mask/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 3 | --output_dir="examples/deepfashoin_mask" \ 4 | --validation_image "examples/deepfashoin_mask/condition_0.png" \ 5 | --validation_prompt "a woman in white shorts and a tank top" \ 6 | --controlnet_model_name_or_path pretrained/deepfashoin_mask/controlnet.safetensors \ 7 | --unet_model_name_or_path pretrained/deepfashoin_mask/unet.safetensors -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/vidit_depth/condition_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/vidit_depth/condition_0.png -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/vidit_depth/eval_img/image_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SD1.5/examples/vidit_depth/eval_img/image_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SD1.5/examples/vidit_depth/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 2 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 3 | --output_dir="examples/vidit_depth" \ 4 | --validation_image "examples/vidit_depth/condition_0.png" \ 5 | --validation_prompt "a wooden bridge in the middle of a field" \ 6 | --controlnet_model_name_or_path pretrained/vidit_depth/controlnet.safetensors \ 7 | --unet_model_name_or_path pretrained/vidit_depth/unet.safetensors -------------------------------------------------------------------------------- /ControlNeXt-SD3/README.md: -------------------------------------------------------------------------------- 1 | # ControlNeXt-SD3 2 | 3 | We regret to inform you that our model of the Stable Diffusion 3 is trained with protected and private data and code, and therefore cannot be released. However, we are considering releasing another model based on the DiT structure in the future. 4 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SDXL 2 | 3 | This is our **training** demo of ControlNeXt based on [Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0). 4 | 5 | Hardware requirement: A single GPU with at least 20GB memory. 6 | 7 | ## Quick Start 8 | 9 | Clone the repository: 10 | 11 | ```bash 12 | git clone https://github.com/dvlab-research/ControlNeXt 13 | cd ControlNeXt/ControlNeXt-SDXL-Training 14 | ``` 15 | 16 | Install the required packages: 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | Run the training script: 23 | 24 | ```bash 25 | bash examples/vidit_depth/train.sh 26 | ``` 27 | 28 | The output will be saved in `train/example`. 29 | 30 | ## Usage 31 | 32 | We recommend to only save & load the weights difference of the UNet's trainable parameters, i.e., $\Delta W = W_{finetune} - W_{pretrained}$, rather than the actual weight. 33 | This is useful when adapting to various base models since the weights difference is model-agnostic. 34 | 35 | ```python 36 | accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \ 37 | --pretrained_vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 38 | --variant fp16 \ 39 | --use_safetensors \ 40 | --output_dir "train/example" \ 41 | --logging_dir "logs" \ 42 | --resolution 1024 \ 43 | --gradient_checkpointing \ 44 | --set_grads_to_none \ 45 | --proportion_empty_prompts 0.2 \ 46 | --controlnet_scale_factor 1.0 \ # the strength of the controlnet output. 
For depth, we recommend 1.0, and for canny, we recommend 0.35 47 | --save_weights_increaments \ 48 | --mixed_precision fp16 \ 49 | --enable_xformers_memory_efficient_attention \ 50 | --dataset_name "Nahrawy/VIDIT-Depth-ControlNet" \ 51 | --image_column "image" \ 52 | --conditioning_image_column "depth_map" \ 53 | --caption_column "caption" \ 54 | --validation_prompt "a stone tower on a rocky island" \ 55 | --validation_image "examples/vidit_depth/condition_0.png" 56 | ``` 57 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/examples/vidit_depth/condition_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL-Training/examples/vidit_depth/condition_0.png -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/examples/vidit_depth/train.sh: -------------------------------------------------------------------------------- 1 | accelerate launch train_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \ 2 | --pretrained_vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 3 | --variant fp16 \ 4 | --use_safetensors \ 5 | --output_dir "train/example" \ 6 | --logging_dir "logs" \ 7 | --resolution 1024 \ 8 | --gradient_checkpointing \ 9 | --set_grads_to_none \ 10 | --proportion_empty_prompts 0.2 \ 11 | --controlnet_scale_factor 1.0 \ 12 | --mixed_precision fp16 \ 13 | --enable_xformers_memory_efficient_attention \ 14 | --dataset_name "Nahrawy/VIDIT-Depth-ControlNet" \ 15 | --image_column "image" \ 16 | --conditioning_image_column "depth_map" \ 17 | --caption_column "caption" \ 18 | --validation_prompt "a stone tower on a rocky island" \ 19 | --validation_image "examples/vidit_depth/condition_0.png" -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | accelerate 4 | opencv-python 5 | pillow 6 | numpy 7 | transformers 8 | diffusers 9 | safetensors 10 | peft 11 | xformers 12 | huggingface-hub 13 | datasets 14 | einops -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def get_extractor(extractor_name): 7 | if extractor_name is None: 8 | return None 9 | if extractor_name not in EXTRACTORS: 10 | raise ValueError(f"Extractor {extractor_name} is not supported.") 11 | return EXTRACTORS[extractor_name] 12 | 13 | 14 | def canny_extractor(image: Image.Image, threshold1=None, threshold2=None) -> Image.Image: 15 | image = np.array(image) 16 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 17 | v = np.median(gray) 18 | 19 | sigma = 0.33 20 | threshold1 = threshold1 or int(max(0, (1.0 - sigma) * v)) 21 | threshold2 = threshold2 or int(min(255, (1.0 + sigma) * v)) 22 | 23 | edges = cv2.Canny(gray, threshold1, threshold2) 24 | edges = Image.fromarray(edges).convert("RGB") 25 | return edges 26 | 27 | 28 | def depth_extractor(image: Image.Image): 29 | raise NotImplementedError("Depth extractor is not implemented yet.") 30 | 31 | 32 | def pose_extractor(image: Image.Image): 33 | raise 
NotImplementedError("Pose extractor is not implemented yet.") 34 | 35 | 36 | EXTRACTORS = { 37 | "canny": canny_extractor, 38 | } 39 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import torch 4 | from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel 5 | from safetensors.torch import load_file 6 | from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline 7 | from models.unet import UNet2DConditionModel 8 | from models.controlnet import ControlNetModel 9 | from . import utils 10 | 11 | UNET_CONFIG = { 12 | "act_fn": "silu", 13 | "addition_embed_type": "text_time", 14 | "addition_embed_type_num_heads": 64, 15 | "addition_time_embed_dim": 256, 16 | "attention_head_dim": [ 17 | 5, 18 | 10, 19 | 20 20 | ], 21 | "block_out_channels": [ 22 | 320, 23 | 640, 24 | 1280 25 | ], 26 | "center_input_sample": False, 27 | "class_embed_type": None, 28 | "class_embeddings_concat": False, 29 | "conv_in_kernel": 3, 30 | "conv_out_kernel": 3, 31 | "cross_attention_dim": 2048, 32 | "cross_attention_norm": None, 33 | "down_block_types": [ 34 | "DownBlock2D", 35 | "CrossAttnDownBlock2D", 36 | "CrossAttnDownBlock2D" 37 | ], 38 | "downsample_padding": 1, 39 | "dual_cross_attention": False, 40 | "encoder_hid_dim": None, 41 | "encoder_hid_dim_type": None, 42 | "flip_sin_to_cos": True, 43 | "freq_shift": 0, 44 | "in_channels": 4, 45 | "layers_per_block": 2, 46 | "mid_block_only_cross_attention": None, 47 | "mid_block_scale_factor": 1, 48 | "mid_block_type": "UNetMidBlock2DCrossAttn", 49 | "norm_eps": 1e-05, 50 | "norm_num_groups": 32, 51 | "num_attention_heads": None, 52 | "num_class_embeds": None, 53 | "only_cross_attention": False, 54 | "out_channels": 4, 55 | "projection_class_embeddings_input_dim": 2816, 56 | "resnet_out_scale_factor": 1.0, 57 | "resnet_skip_time_act": False, 58 | "resnet_time_scale_shift": "default", 59 | "sample_size": 128, 60 | "time_cond_proj_dim": None, 61 | "time_embedding_act_fn": None, 62 | "time_embedding_dim": None, 63 | "time_embedding_type": "positional", 64 | "timestep_post_act": None, 65 | "transformer_layers_per_block": [ 66 | 1, 67 | 2, 68 | 10 69 | ], 70 | "up_block_types": [ 71 | "CrossAttnUpBlock2D", 72 | "CrossAttnUpBlock2D", 73 | "UpBlock2D" 74 | ], 75 | "upcast_attention": None, 76 | "use_linear_projection": True 77 | } 78 | 79 | CONTROLNET_CONFIG = { 80 | 'in_channels': [128, 128], 81 | 'out_channels': [128, 256], 82 | 'groups': [4, 8], 83 | 'time_embed_dim': 256, 84 | 'final_out_channels': 320, 85 | '_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels'] 86 | } 87 | 88 | 89 | def get_pipeline( 90 | pretrained_model_name_or_path, 91 | unet_model_name_or_path, 92 | controlnet_model_name_or_path, 93 | vae_model_name_or_path=None, 94 | lora_path=None, 95 | load_weight_increasement=False, 96 | enable_xformers_memory_efficient_attention=False, 97 | revision=None, 98 | variant=None, 99 | hf_cache_dir=None, 100 | use_safetensors=True, 101 | device=None, 102 | ): 103 | pipeline_init_kwargs = {} 104 | 105 | print(f"loading unet from {pretrained_model_name_or_path}") 106 | if os.path.isfile(pretrained_model_name_or_path): 107 | # load unet from local checkpoint 108 | unet_sd = load_file(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".safetensors") else 
torch.load(pretrained_model_name_or_path) 109 | unet_sd = utils.extract_unet_state_dict(unet_sd) 110 | unet_sd = utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd) 111 | unet = UNet2DConditionModel.from_config(UNET_CONFIG) 112 | unet.load_state_dict(unet_sd, strict=True) 113 | else: 114 | unet = UNet2DConditionModel.from_pretrained( 115 | pretrained_model_name_or_path, 116 | cache_dir=hf_cache_dir, 117 | variant=variant, 118 | torch_dtype=torch.float16, 119 | use_safetensors=use_safetensors, 120 | subfolder="unet", 121 | ) 122 | unet = unet.to(dtype=torch.float16) 123 | pipeline_init_kwargs["unet"] = unet 124 | 125 | if vae_model_name_or_path is not None: 126 | print(f"loading vae from {vae_model_name_or_path}") 127 | vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device) 128 | pipeline_init_kwargs["vae"] = vae 129 | 130 | if controlnet_model_name_or_path is not None: 131 | pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32) # init 132 | 133 | print(f"loading pipeline from {pretrained_model_name_or_path}") 134 | if os.path.isfile(pretrained_model_name_or_path): 135 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file( 136 | pretrained_model_name_or_path, 137 | use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"), 138 | local_files_only=True, 139 | cache_dir=hf_cache_dir, 140 | **pipeline_init_kwargs, 141 | ) 142 | else: 143 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained( 144 | pretrained_model_name_or_path, 145 | revision=revision, 146 | variant=variant, 147 | use_safetensors=use_safetensors, 148 | cache_dir=hf_cache_dir, 149 | **pipeline_init_kwargs, 150 | ) 151 | 152 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 153 | if unet_model_name_or_path is not None: 154 | print(f"loading controlnext unet from {unet_model_name_or_path}") 155 | pipeline.load_controlnext_unet_weights( 156 | unet_model_name_or_path, 157 | load_weight_increasement=load_weight_increasement, 158 | use_safetensors=True, 159 | torch_dtype=torch.float16, 160 | cache_dir=hf_cache_dir, 161 | ) 162 | if controlnet_model_name_or_path is not None: 163 | print(f"loading controlnext controlnet from {controlnet_model_name_or_path}") 164 | pipeline.load_controlnext_controlnet_weights( 165 | controlnet_model_name_or_path, 166 | use_safetensors=True, 167 | torch_dtype=torch.float32, 168 | cache_dir=hf_cache_dir, 169 | ) 170 | pipeline.set_progress_bar_config() 171 | pipeline = pipeline.to(device, dtype=torch.float16) 172 | 173 | if lora_path is not None: 174 | pipeline.load_lora_weights(lora_path) 175 | if enable_xformers_memory_efficient_attention: 176 | pipeline.enable_xformers_memory_efficient_attention() 177 | 178 | gc.collect() 179 | if str(device) == 'cuda' and torch.cuda.is_available(): 180 | torch.cuda.empty_cache() 181 | 182 | return pipeline 183 | 184 | 185 | def get_scheduler( 186 | scheduler_name, 187 | scheduler_config, 188 | ): 189 | if scheduler_name == 'Euler A': 190 | from diffusers.schedulers import EulerAncestralDiscreteScheduler 191 | scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) 192 | elif scheduler_name == 'UniPC': 193 | from diffusers.schedulers import UniPCMultistepScheduler 194 | scheduler = UniPCMultistepScheduler.from_config(scheduler_config) 195 | elif scheduler_name == 'Euler': 
196 | from diffusers.schedulers import EulerDiscreteScheduler 197 | scheduler = EulerDiscreteScheduler.from_config(scheduler_config) 198 | elif scheduler_name == 'DDIM': 199 | from diffusers.schedulers import DDIMScheduler 200 | scheduler = DDIMScheduler.from_config(scheduler_config) 201 | elif scheduler_name == 'DDPM': 202 | from diffusers.schedulers import DDPMScheduler 203 | scheduler = DDPMScheduler.from_config(scheduler_config) 204 | else: 205 | raise ValueError(f"Unknown scheduler: {scheduler_name}") 206 | return scheduler 207 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL-Training/utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Union, Optional 3 | 4 | 5 | def make_unet_conversion_map(): 6 | unet_conversion_map_layer = [] 7 | 8 | for i in range(3): # num_blocks is 3 in sdxl 9 | # loop over downblocks/upblocks 10 | for j in range(2): 11 | # loop over resnets/attentions for downblocks 12 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 13 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 14 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 15 | 16 | if i < 3: 17 | # no attention layers in down_blocks.3 18 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 19 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 20 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 21 | 22 | for j in range(3): 23 | # loop over resnets/attentions for upblocks 24 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 25 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 26 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 27 | 28 | # if i > 0: commentout for sdxl 29 | # no attention layers in up_blocks.0 30 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 31 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 32 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 33 | 34 | if i < 3: 35 | # no downsample in down_blocks.3 36 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 37 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 38 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 39 | 40 | # no upsample in up_blocks.3 41 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 42 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl 43 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 44 | 45 | hf_mid_atn_prefix = "mid_block.attentions.0." 46 | sd_mid_atn_prefix = "middle_block.1." 47 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 48 | 49 | for j in range(2): 50 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 51 | sd_mid_res_prefix = f"middle_block.{2*j}." 
52 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 53 | 54 | unet_conversion_map_resnet = [ 55 | # (stable-diffusion, HF Diffusers) 56 | ("in_layers.0.", "norm1."), 57 | ("in_layers.2.", "conv1."), 58 | ("out_layers.0.", "norm2."), 59 | ("out_layers.3.", "conv2."), 60 | ("emb_layers.1.", "time_emb_proj."), 61 | ("skip_connection.", "conv_shortcut."), 62 | ] 63 | 64 | unet_conversion_map = [] 65 | for sd, hf in unet_conversion_map_layer: 66 | if "resnets" in hf: 67 | for sd_res, hf_res in unet_conversion_map_resnet: 68 | unet_conversion_map.append((sd + sd_res, hf + hf_res)) 69 | else: 70 | unet_conversion_map.append((sd, hf)) 71 | 72 | for j in range(2): 73 | hf_time_embed_prefix = f"time_embedding.linear_{j+1}." 74 | sd_time_embed_prefix = f"time_embed.{j*2}." 75 | unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) 76 | 77 | for j in range(2): 78 | hf_label_embed_prefix = f"add_embedding.linear_{j+1}." 79 | sd_label_embed_prefix = f"label_emb.0.{j*2}." 80 | unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) 81 | 82 | unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) 83 | unet_conversion_map.append(("out.0.", "conv_norm_out.")) 84 | unet_conversion_map.append(("out.2.", "conv_out.")) 85 | 86 | return unet_conversion_map 87 | 88 | 89 | def convert_unet_state_dict(src_sd, conversion_map): 90 | converted_sd = {} 91 | for src_key, value in src_sd.items(): 92 | src_key_fragments = src_key.split(".")[:-1] # remove weight/bias 93 | while len(src_key_fragments) > 0: 94 | src_key_prefix = ".".join(src_key_fragments) + "." 95 | if src_key_prefix in conversion_map: 96 | converted_prefix = conversion_map[src_key_prefix] 97 | converted_key = converted_prefix + src_key[len(src_key_prefix):] 98 | converted_sd[converted_key] = value 99 | break 100 | src_key_fragments.pop(-1) 101 | assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" 102 | 103 | return converted_sd 104 | 105 | 106 | def convert_sdxl_unet_state_dict_to_diffusers(sd): 107 | unet_conversion_map = make_unet_conversion_map() 108 | 109 | conversion_dict = {sd: hf for sd, hf in unet_conversion_map} 110 | return convert_unet_state_dict(sd, conversion_dict) 111 | 112 | 113 | def extract_unet_state_dict(state_dict): 114 | unet_sd = {} 115 | UNET_KEY_PREFIX = "model.diffusion_model." 
116 | for k, v in state_dict.items(): 117 | if k.startswith(UNET_KEY_PREFIX): 118 | unet_sd[k[len(UNET_KEY_PREFIX):]] = v 119 | return unet_sd 120 | 121 | 122 | def log_model_info(model, name): 123 | sd = model.state_dict() if hasattr(model, "state_dict") else model 124 | print( 125 | f"{name}:", 126 | f" number of parameters: {sum(p.numel() for p in sd.values())}", 127 | f" dtype: {sd[next(iter(sd))].dtype}", 128 | sep='\n' 129 | ) 130 | 131 | 132 | def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]: 133 | r""" 134 | w*h = reso*reso 135 | w/h = img_w/img_h 136 | => w = img_ar*h 137 | => img_ar*h^2 = reso 138 | => h = sqrt(reso / img_ar) 139 | """ 140 | reso = reso if isinstance(reso, tuple) else (reso, reso) 141 | divisible = divisible or 1 142 | if img_w * img_h <= reso[0] * reso[1] and (not max_width or img_w <= max_width) and (not max_height or img_h <= max_height) and img_w % divisible == 0 and img_h % divisible == 0: 143 | return (img_w, img_h) 144 | img_ar = img_w / img_h 145 | around_h = math.sqrt(reso[0]*reso[1] / img_ar) 146 | around_w = img_ar * around_h // divisible * divisible 147 | if max_width and around_w > max_width: 148 | around_h = around_h * max_width // around_w 149 | around_w = max_width 150 | elif max_height and around_h > max_height: 151 | around_w = around_w * max_height // around_h 152 | around_h = max_height 153 | around_h = min(around_h, max_height) if max_height else around_h 154 | around_w = min(around_w, max_width) if max_width else around_w 155 | around_h = int(around_h // divisible * divisible) 156 | around_w = int(around_w // divisible * divisible) 157 | return (around_w, around_h) 158 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SDXL 2 | 3 | This is our implementation of ControlNeXt based on [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). 4 | 5 | > Please refer to [Examples](#examples) for further intuitive details.\ 6 | > Please refer to [Inference](#inference) for more details regarding installation and inference. 7 | 8 | Our method demonstrates the advantages listed below: 9 | 10 | - **Few trainable parameters**: only **5~200M** trainable parameters are required. 11 | - **Fast training speed**: converges quickly and reduces the "sudden convergence" problem. 12 | - **Efficient**: no additional branch is needed; only a lightweight module is required. 13 | - **Compatibility**: can serve as a **plug-and-play** lightweight module and can be combined with other LoRA weights. 14 | 15 | # Examples 16 | 17 | The demo examples are generated with ControlNeXt trained on: 18 | 19 | - (i) our vidit_depth dataset, using [Stable Diffusion XL 1.0 Base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) as the base model. 20 | - (ii) our anime_canny dataset, using [Neta Art XL 2.0](https://civitai.com/models/410737/neta-art-xl) as the base model. 21 | 22 | Our method demonstrates excellent compatibility and can be applied to most other models based on the SDXL 1.0 architecture, as well as combined with LoRA weights. You can also retrain your own model for better performance. 23 | 24 |

25 | ![demo1](examples/demo/demo1.jpg) 26 | ![demo2](examples/demo/demo2.jpg) 27 | ![demo3](examples/demo/demo3.jpg) 28 | ![demo5](examples/demo/demo5.jpg) 29 |

30 | 31 | ## BaseModel 32 | 33 | Our model can be applied to various base models as a plug-and-play module, without the need for further training (see the usage sketch after the LoRA examples below). 34 | 35 | > 📌 Of course, you can retrain your own model, especially for complex tasks and to achieve better performance. 36 | 37 | - [Stable Diffusion XL 1.0 Base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 38 | 39 |

40 | ![StableDiffusionXL](examples/vidit_depth/eval_img/StableDiffusionXL.jpg) 41 |

42 | 43 | - [AAM XL](https://huggingface.co/Lykon/AAM_XL_AnimeMix) 44 | 45 |

46 | ![AAM](examples/anime_canny/eval_img/AAM.jpg) 47 |

48 | 49 | - [Neta XL V2](https://civitai.com/models/410737/neta-art-xl) 50 | 51 |

52 | ![NetaXLV2](examples/anime_canny/eval_img/NetaXLV2.jpg) 53 |

54 | 55 | ## LoRA 56 | 57 | Our model can also be directly combined with other publicly available LoRA weights. 58 | 59 | - [Glass Sculptures](https://civitai.com/models/11203/glass-sculptures?modelVersionId=177888) 60 | 61 |

62 | ![StableDiffusionXL + Glass Sculptures LoRA](examples/vidit_depth/eval_img/StableDiffusionXL_GlassSculpturesLora.jpg) 63 |
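To make the plug-and-play and LoRA claims above concrete, here is a minimal sketch (not an official example) that builds the pipeline on a different SDXL base model together with a LoRA, using this repository's own helper `get_pipeline` from `utils/tools.py`. The base-model and ControlNeXt IDs come from the examples above; the LoRA path is a placeholder you would replace with your own file.

```python
# Minimal sketch: ControlNeXt as a plug-and-play module on another SDXL base model, plus a LoRA.
# Assumes it is run from ControlNeXt-SDXL/ after `pip install -r requirements.txt`.
import torch
from utils.tools import get_pipeline

pipeline = get_pipeline(
    pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix",            # any SDXL-1.0-based model
    unet_model_name_or_path="Eugeoter/controlnext-sdxl-anime-canny",  # ControlNeXt UNet weight difference
    controlnet_model_name_or_path="Eugeoter/controlnext-sdxl-anime-canny",
    vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix",
    lora_path="lora/your_lora.safetensors",                           # placeholder: any SDXL LoRA file
    load_weight_increasement=True,
    device=torch.device("cuda"),
)
```

The `run_controlnext.py` script used in the next section exposes the same options as command-line flags, so the shell scripts under `examples/` are usually the more convenient entry point.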

64 | 65 | # Inference 66 | 67 | ## Quick Start 68 | 69 | Clone the repository: 70 | 71 | ```bash 72 | git clone https://github.com/dvlab-research/ControlNeXt 73 | cd ControlNeXt/ControlNeXt-SDXL 74 | ``` 75 | 76 | Install the required packages: 77 | 78 | ```bash 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | (Optional) Download a LoRA weight, such as [Amiya (Arknights) Fresh Art Style](https://civitai.com/models/231598/amiya-arknights-fresh-art-style-xl-trained-with-6k-images), and put it under `lora/`. 83 | 84 | Run the example: 85 | 86 | ```bash 87 | bash examples/anime_canny/run.sh 88 | ``` 89 | 90 | ## Usage 91 | 92 | ### Canny Condition 93 | 94 | ```bash 95 | # examples/anime_canny/run.sh 96 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \ 97 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 98 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 99 | --controlnet_scale 1.0 \ # controlnet scale factor used to adjust the strength of the control condition 100 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 101 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \ 102 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \ 103 | --validation_image "examples/anime_canny/condition_0.jpg" \ # input canny image 104 | --output_dir "examples/anime_canny" \ 105 | --load_weight_increasement # load the UNet weight difference instead of the full weights 106 | ``` 107 | 108 | We use a `controlnet_scale` factor to adjust the strength of the control condition. 109 | 110 | We recommend saving & loading only the weight difference of the UNet's trainable parameters, i.e., $\Delta W = W_{finetune} - W_{pretrained}$, rather than the full weights. 111 | This is useful when adapting to various base models, since the weight difference is largely model-agnostic: at inference time it is simply added back onto the chosen base model ($W = W_{base} + \Delta W$), which is what `--load_weight_increasement` does. 112 | 113 | ### Depth Condition 114 | 115 | ```bash 116 | # examples/vidit_depth/run.sh 117 | python run_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \ 118 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \ 119 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \ 120 | --controlnet_scale 1.0 \ 121 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 122 | --validation_prompt "a diamond tower in the middle of a lava lake" \ 123 | --validation_image "examples/vidit_depth/condition_0.png" \ # input depth image 124 | --output_dir "examples/vidit_depth" \ 125 | --width 1024 \ 126 | --height 1024 \ 127 | --load_weight_increasement \ 128 | --variant fp16 129 | ``` 130 | 131 | ## Run with Image Processor 132 | 133 | We also provide a simple image processor that automatically converts an input image into the control condition, such as canny.
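For reference, this processor lives in `utils/preprocess.py`. Below is a minimal sketch of calling the canny extractor directly; the output filename is just an illustrative choice, and the thresholds are picked automatically around the median gray level when not given.

```python
# Sketch: what the built-in "canny" image processor does (see utils/preprocess.py).
from PIL import Image
from utils.preprocess import get_extractor

extractor = get_extractor("canny")                      # returns the canny_extractor function
image = Image.open("examples/anime_canny/image_0.jpg")  # an ordinary RGB image (not an edge map)
condition = extractor(image)                            # RGB Canny edge map, usable as the control condition
condition.save("anime_canny_condition.jpg")             # arbitrary output path for inspection
```

The command below performs this preprocessing automatically via the `--validation_image_processor "canny"` flag: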
134 | 135 | ```bash 136 | # examples/anime_canny/run_with_pp.sh 137 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \ 138 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 139 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 140 | --controlnet_scale 1.0 \ 141 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 142 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \ 143 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \ 144 | --validation_image "examples/anime_canny/image_0.jpg" \ # input image (not canny) 145 | --validation_image_processor "canny" \ # preprocess `validation_image` to canny condition 146 | --output_dir "examples/anime_canny" \ 147 | --load_weight_increasement 148 | ``` 149 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/condition_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/condition_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/eval_img/AAM.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/eval_img/AAM.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/eval_img/NetaXLV2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/eval_img/NetaXLV2.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/image_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/anime_canny/image_0.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/run.sh: -------------------------------------------------------------------------------- 1 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \ 2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 4 | --controlnet_scale 1.0 \ 5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 6 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \ 7 | --negative_prompt "worst quality, abstract, clumsy pose, 
deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \ 8 | --validation_image "examples/anime_canny/condition_0.jpg" \ 9 | --output_dir "examples/anime_canny" \ 10 | --load_weight_increasement \ 11 | --use_safetensors \ 12 | --variant fp16 13 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/anime_canny/run_with_pp.sh: -------------------------------------------------------------------------------- 1 | python run_controlnext.py --pretrained_model_name_or_path "neta-art/neta-xl-2.0" \ 2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-anime-canny" \ 4 | --controlnet_scale 1.0 \ 5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 6 | --validation_prompt "3d style, photorealistic style, 1girl, arknights, amiya (arknights), solo, white background, upper body, looking at viewer, blush, closed mouth, low ponytail, black jacket, hooded jacket, open jacket, hood down, blue neckwear" \ 7 | --negative_prompt "worst quality, abstract, clumsy pose, deformed hand, fused fingers, extra digits, fewer digits, fewer fingers, extra fingers, extra arm, missing arm, extra leg, missing leg, signature, artist name, multi views, disfigured, ugly" \ 8 | --validation_image "examples/anime_canny/image_0.jpg" \ 9 | --validation_image_processor "canny" \ 10 | --output_dir "examples/anime_canny" \ 11 | --load_weight_increasement \ 12 | --use_safetensors \ 13 | --variant fp16 14 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/demo/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo1.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/demo/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo2.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/demo/demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo3.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/demo/demo4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo4.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/demo/demo5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/demo/demo5.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/vidit_depth/condition_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/condition_0.png -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL_GlassSculpturesLora.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SDXL/examples/vidit_depth/eval_img/StableDiffusionXL_GlassSculpturesLora.jpg -------------------------------------------------------------------------------- /ControlNeXt-SDXL/examples/vidit_depth/run.sh: -------------------------------------------------------------------------------- 1 | python run_controlnext.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-xl-base-1.0" \ 2 | --unet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \ 3 | --controlnet_model_name_or_path "Eugeoter/controlnext-sdxl-vidit-depth" \ 4 | --controlnet_scale 1.0 \ 5 | --vae_model_name_or_path "madebyollin/sdxl-vae-fp16-fix" \ 6 | --validation_prompt "a diamond tower in the middle of a lava lake" \ 7 | --validation_image "examples/vidit_depth/condition_0.png" \ 8 | --output_dir "examples/vidit_depth" \ 9 | --width 1024 \ 10 | --height 1024 \ 11 | --load_weight_increasement \ 12 | --variant fp16 13 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | opencv-python 3 | pillow 4 | numpy 5 | transformers 6 | diffusers 7 | safetensors 8 | peft 9 | einops -------------------------------------------------------------------------------- /ControlNeXt-SDXL/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def get_extractor(extractor_name): 7 | if extractor_name is None: 8 | return None 9 | if extractor_name not in EXTRACTORS: 10 | raise ValueError(f"Extractor {extractor_name} is not supported.") 11 | return EXTRACTORS[extractor_name] 12 | 13 | 14 | def canny_extractor(image: Image.Image, threshold1=None, threshold2=None) -> Image.Image: 15 | image = np.array(image) 16 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 17 | v = np.median(gray) 18 | 19 | sigma = 0.33 20 | threshold1 = threshold1 or int(max(0, (1.0 - sigma) * v)) 21 | threshold2 = threshold2 or int(min(255, (1.0 + sigma) * v)) 22 | 23 | edges = cv2.Canny(gray, threshold1, threshold2) 24 | edges = Image.fromarray(edges).convert("RGB") 25 | return edges 26 | 27 | 28 | def depth_extractor(image: Image.Image): 29 | raise NotImplementedError("Depth extractor is not implemented yet.") 30 | 31 | 32 | def pose_extractor(image: Image.Image): 33 | raise NotImplementedError("Pose extractor is not implemented yet.") 34 | 35 | 36 | EXTRACTORS = { 37 | "canny": canny_extractor, 
38 | } 39 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import torch 4 | from diffusers import UniPCMultistepScheduler, AutoencoderKL, ControlNetModel 5 | from safetensors.torch import load_file 6 | from pipeline.pipeline_controlnext import StableDiffusionXLControlNeXtPipeline 7 | from models.unet import UNet2DConditionModel 8 | from models.controlnet import ControlNetModel 9 | from . import utils 10 | 11 | UNET_CONFIG = { 12 | "act_fn": "silu", 13 | "addition_embed_type": "text_time", 14 | "addition_embed_type_num_heads": 64, 15 | "addition_time_embed_dim": 256, 16 | "attention_head_dim": [ 17 | 5, 18 | 10, 19 | 20 20 | ], 21 | "block_out_channels": [ 22 | 320, 23 | 640, 24 | 1280 25 | ], 26 | "center_input_sample": False, 27 | "class_embed_type": None, 28 | "class_embeddings_concat": False, 29 | "conv_in_kernel": 3, 30 | "conv_out_kernel": 3, 31 | "cross_attention_dim": 2048, 32 | "cross_attention_norm": None, 33 | "down_block_types": [ 34 | "DownBlock2D", 35 | "CrossAttnDownBlock2D", 36 | "CrossAttnDownBlock2D" 37 | ], 38 | "downsample_padding": 1, 39 | "dual_cross_attention": False, 40 | "encoder_hid_dim": None, 41 | "encoder_hid_dim_type": None, 42 | "flip_sin_to_cos": True, 43 | "freq_shift": 0, 44 | "in_channels": 4, 45 | "layers_per_block": 2, 46 | "mid_block_only_cross_attention": None, 47 | "mid_block_scale_factor": 1, 48 | "mid_block_type": "UNetMidBlock2DCrossAttn", 49 | "norm_eps": 1e-05, 50 | "norm_num_groups": 32, 51 | "num_attention_heads": None, 52 | "num_class_embeds": None, 53 | "only_cross_attention": False, 54 | "out_channels": 4, 55 | "projection_class_embeddings_input_dim": 2816, 56 | "resnet_out_scale_factor": 1.0, 57 | "resnet_skip_time_act": False, 58 | "resnet_time_scale_shift": "default", 59 | "sample_size": 128, 60 | "time_cond_proj_dim": None, 61 | "time_embedding_act_fn": None, 62 | "time_embedding_dim": None, 63 | "time_embedding_type": "positional", 64 | "timestep_post_act": None, 65 | "transformer_layers_per_block": [ 66 | 1, 67 | 2, 68 | 10 69 | ], 70 | "up_block_types": [ 71 | "CrossAttnUpBlock2D", 72 | "CrossAttnUpBlock2D", 73 | "UpBlock2D" 74 | ], 75 | "upcast_attention": None, 76 | "use_linear_projection": True 77 | } 78 | 79 | CONTROLNET_CONFIG = { 80 | 'in_channels': [128, 128], 81 | 'out_channels': [128, 256], 82 | 'groups': [4, 8], 83 | 'time_embed_dim': 256, 84 | 'final_out_channels': 320, 85 | '_use_default_values': ['time_embed_dim', 'groups', 'in_channels', 'final_out_channels', 'out_channels'] 86 | } 87 | 88 | 89 | def get_pipeline( 90 | pretrained_model_name_or_path, 91 | unet_model_name_or_path, 92 | controlnet_model_name_or_path, 93 | vae_model_name_or_path=None, 94 | lora_path=None, 95 | load_weight_increasement=False, 96 | enable_xformers_memory_efficient_attention=False, 97 | revision=None, 98 | variant=None, 99 | hf_cache_dir=None, 100 | use_safetensors=True, 101 | device=None, 102 | ): 103 | pipeline_init_kwargs = {} 104 | 105 | print(f"loading unet from {pretrained_model_name_or_path}") 106 | if os.path.isfile(pretrained_model_name_or_path): 107 | # load unet from local checkpoint 108 | unet_sd = load_file(pretrained_model_name_or_path) if pretrained_model_name_or_path.endswith(".safetensors") else torch.load(pretrained_model_name_or_path) 109 | unet_sd = utils.extract_unet_state_dict(unet_sd) 110 | unet_sd = 
utils.convert_sdxl_unet_state_dict_to_diffusers(unet_sd) 111 | unet = UNet2DConditionModel.from_config(UNET_CONFIG) 112 | unet.load_state_dict(unet_sd, strict=True) 113 | else: 114 | unet = UNet2DConditionModel.from_pretrained( 115 | pretrained_model_name_or_path, 116 | cache_dir=hf_cache_dir, 117 | variant=variant, 118 | torch_dtype=torch.float16, 119 | use_safetensors=use_safetensors, 120 | subfolder="unet", 121 | ) 122 | unet = unet.to(dtype=torch.float16) 123 | pipeline_init_kwargs["unet"] = unet 124 | 125 | if vae_model_name_or_path is not None: 126 | print(f"loading vae from {vae_model_name_or_path}") 127 | vae = AutoencoderKL.from_pretrained(vae_model_name_or_path, cache_dir=hf_cache_dir, torch_dtype=torch.float16).to(device) 128 | pipeline_init_kwargs["vae"] = vae 129 | 130 | if controlnet_model_name_or_path is not None: 131 | pipeline_init_kwargs["controlnet"] = ControlNetModel.from_config(CONTROLNET_CONFIG).to(device, dtype=torch.float32) # init 132 | 133 | print(f"loading pipeline from {pretrained_model_name_or_path}") 134 | if os.path.isfile(pretrained_model_name_or_path): 135 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_single_file( 136 | pretrained_model_name_or_path, 137 | use_safetensors=pretrained_model_name_or_path.endswith(".safetensors"), 138 | local_files_only=True, 139 | cache_dir=hf_cache_dir, 140 | **pipeline_init_kwargs, 141 | ) 142 | else: 143 | pipeline: StableDiffusionXLControlNeXtPipeline = StableDiffusionXLControlNeXtPipeline.from_pretrained( 144 | pretrained_model_name_or_path, 145 | revision=revision, 146 | variant=variant, 147 | use_safetensors=use_safetensors, 148 | cache_dir=hf_cache_dir, 149 | **pipeline_init_kwargs, 150 | ) 151 | 152 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 153 | if unet_model_name_or_path is not None: 154 | print(f"loading controlnext unet from {unet_model_name_or_path}") 155 | pipeline.load_controlnext_unet_weights( 156 | unet_model_name_or_path, 157 | load_weight_increasement=load_weight_increasement, 158 | use_safetensors=True, 159 | torch_dtype=torch.float16, 160 | cache_dir=hf_cache_dir, 161 | ) 162 | if controlnet_model_name_or_path is not None: 163 | print(f"loading controlnext controlnet from {controlnet_model_name_or_path}") 164 | pipeline.load_controlnext_controlnet_weights( 165 | controlnet_model_name_or_path, 166 | use_safetensors=True, 167 | torch_dtype=torch.float32, 168 | cache_dir=hf_cache_dir, 169 | ) 170 | pipeline.set_progress_bar_config() 171 | pipeline = pipeline.to(device, dtype=torch.float16) 172 | 173 | if lora_path is not None: 174 | pipeline.load_lora_weights(lora_path) 175 | if enable_xformers_memory_efficient_attention: 176 | pipeline.enable_xformers_memory_efficient_attention() 177 | 178 | gc.collect() 179 | if str(device) == 'cuda' and torch.cuda.is_available(): 180 | torch.cuda.empty_cache() 181 | 182 | return pipeline 183 | 184 | 185 | def get_scheduler( 186 | scheduler_name, 187 | scheduler_config, 188 | ): 189 | if scheduler_name == 'Euler A': 190 | from diffusers.schedulers import EulerAncestralDiscreteScheduler 191 | scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) 192 | elif scheduler_name == 'UniPC': 193 | from diffusers.schedulers import UniPCMultistepScheduler 194 | scheduler = UniPCMultistepScheduler.from_config(scheduler_config) 195 | elif scheduler_name == 'Euler': 196 | from diffusers.schedulers import EulerDiscreteScheduler 197 | scheduler = 
EulerDiscreteScheduler.from_config(scheduler_config) 198 | elif scheduler_name == 'DDIM': 199 | from diffusers.schedulers import DDIMScheduler 200 | scheduler = DDIMScheduler.from_config(scheduler_config) 201 | elif scheduler_name == 'DDPM': 202 | from diffusers.schedulers import DDPMScheduler 203 | scheduler = DDPMScheduler.from_config(scheduler_config) 204 | else: 205 | raise ValueError(f"Unknown scheduler: {scheduler_name}") 206 | return scheduler 207 | -------------------------------------------------------------------------------- /ControlNeXt-SDXL/utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Union, Optional 3 | 4 | 5 | def make_unet_conversion_map(): 6 | unet_conversion_map_layer = [] 7 | 8 | for i in range(3): # num_blocks is 3 in sdxl 9 | # loop over downblocks/upblocks 10 | for j in range(2): 11 | # loop over resnets/attentions for downblocks 12 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 13 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 14 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 15 | 16 | if i < 3: 17 | # no attention layers in down_blocks.3 18 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 19 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 20 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 21 | 22 | for j in range(3): 23 | # loop over resnets/attentions for upblocks 24 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 25 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 26 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 27 | 28 | # if i > 0: commentout for sdxl 29 | # no attention layers in up_blocks.0 30 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 31 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 32 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 33 | 34 | if i < 3: 35 | # no downsample in down_blocks.3 36 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 37 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 38 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 39 | 40 | # no upsample in up_blocks.3 41 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 42 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl 43 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 44 | 45 | hf_mid_atn_prefix = "mid_block.attentions.0." 46 | sd_mid_atn_prefix = "middle_block.1." 47 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 48 | 49 | for j in range(2): 50 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 51 | sd_mid_res_prefix = f"middle_block.{2*j}." 52 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 53 | 54 | unet_conversion_map_resnet = [ 55 | # (stable-diffusion, HF Diffusers) 56 | ("in_layers.0.", "norm1."), 57 | ("in_layers.2.", "conv1."), 58 | ("out_layers.0.", "norm2."), 59 | ("out_layers.3.", "conv2."), 60 | ("emb_layers.1.", "time_emb_proj."), 61 | ("skip_connection.", "conv_shortcut."), 62 | ] 63 | 64 | unet_conversion_map = [] 65 | for sd, hf in unet_conversion_map_layer: 66 | if "resnets" in hf: 67 | for sd_res, hf_res in unet_conversion_map_resnet: 68 | unet_conversion_map.append((sd + sd_res, hf + hf_res)) 69 | else: 70 | unet_conversion_map.append((sd, hf)) 71 | 72 | for j in range(2): 73 | hf_time_embed_prefix = f"time_embedding.linear_{j+1}." 
74 | sd_time_embed_prefix = f"time_embed.{j*2}." 75 | unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) 76 | 77 | for j in range(2): 78 | hf_label_embed_prefix = f"add_embedding.linear_{j+1}." 79 | sd_label_embed_prefix = f"label_emb.0.{j*2}." 80 | unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) 81 | 82 | unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) 83 | unet_conversion_map.append(("out.0.", "conv_norm_out.")) 84 | unet_conversion_map.append(("out.2.", "conv_out.")) 85 | 86 | return unet_conversion_map 87 | 88 | 89 | def convert_unet_state_dict(src_sd, conversion_map): 90 | converted_sd = {} 91 | for src_key, value in src_sd.items(): 92 | src_key_fragments = src_key.split(".")[:-1] # remove weight/bias 93 | while len(src_key_fragments) > 0: 94 | src_key_prefix = ".".join(src_key_fragments) + "." 95 | if src_key_prefix in conversion_map: 96 | converted_prefix = conversion_map[src_key_prefix] 97 | converted_key = converted_prefix + src_key[len(src_key_prefix):] 98 | converted_sd[converted_key] = value 99 | break 100 | src_key_fragments.pop(-1) 101 | assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" 102 | 103 | return converted_sd 104 | 105 | 106 | def convert_sdxl_unet_state_dict_to_diffusers(sd): 107 | unet_conversion_map = make_unet_conversion_map() 108 | 109 | conversion_dict = {sd: hf for sd, hf in unet_conversion_map} 110 | return convert_unet_state_dict(sd, conversion_dict) 111 | 112 | 113 | def extract_unet_state_dict(state_dict): 114 | unet_sd = {} 115 | UNET_KEY_PREFIX = "model.diffusion_model." 116 | for k, v in state_dict.items(): 117 | if k.startswith(UNET_KEY_PREFIX): 118 | unet_sd[k[len(UNET_KEY_PREFIX):]] = v 119 | return unet_sd 120 | 121 | 122 | def log_model_info(model, name): 123 | sd = model.state_dict() if hasattr(model, "state_dict") else model 124 | print( 125 | f"{name}:", 126 | f" number of parameters: {sum(p.numel() for p in sd.values())}", 127 | f" dtype: {sd[next(iter(sd))].dtype}", 128 | sep='\n' 129 | ) 130 | 131 | 132 | def around_reso(img_w, img_h, reso: Union[Tuple[int, int], int], divisible: Optional[int] = None, max_width=None, max_height=None) -> Tuple[int, int]: 133 | r""" 134 | w*h = reso*reso 135 | w/h = img_w/img_h 136 | => w = img_ar*h 137 | => img_ar*h^2 = reso 138 | => h = sqrt(reso / img_ar) 139 | """ 140 | reso = reso if isinstance(reso, tuple) else (reso, reso) 141 | divisible = divisible or 1 142 | if img_w * img_h <= reso[0] * reso[1] and (not max_width or img_w <= max_width) and (not max_height or img_h <= max_height) and img_w % divisible == 0 and img_h % divisible == 0: 143 | return (img_w, img_h) 144 | img_ar = img_w / img_h 145 | around_h = math.sqrt(reso[0]*reso[1] / img_ar) 146 | around_w = img_ar * around_h // divisible * divisible 147 | if max_width and around_w > max_width: 148 | around_h = around_h * max_width // around_w 149 | around_w = max_width 150 | elif max_height and around_h > max_height: 151 | around_w = around_w * max_height // around_h 152 | around_h = max_height 153 | around_h = min(around_h, max_height) if max_height else around_h 154 | around_w = min(around_w, max_width) if max_width else around_w 155 | around_h = int(around_h // divisible * divisible) 156 | around_w = int(around_w // divisible * divisible) 157 | return (around_w, around_h) 158 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/README.md: 
-------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SVD-v2-Training 2 | 3 | # Important 4 | 5 | I have found that sometimes, when I change the versions of the dependencies, training may not converge at all. I haven't identified the reason yet, but I've listed all our dependencies in the [requirements.txt](./requirements.txt) file. It's a bit detailed, but you can focus on the key dependencies such as `torch`, `deepspeed`, `diffusers`, `accelerate`... When issues arise, checking these first may help. (We use `diffusers==0.25.0`.) 6 | 7 | If you notice differences between the training and inference scripts, such as the `import` paths, please follow the training script! 8 | 9 | ## Main 10 | 11 | Due to privacy concerns, we are unable to release certain resources, such as the training data and the SD3-based model. However, we are committed to sharing as much as possible. If you find this repository helpful, please consider giving us a star or citing our work! 12 | 13 | The training scripts are intended for users with a basic understanding of `Python` and `Diffusers`, so we will not spell out every detail. If you have any questions, please refer to the code first. Thank you! If you encounter any bugs, please contact us and let us know. 14 | 15 | ## Experiences 16 | 17 | We share more training experience in this [Issue](https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333) and [here](../experiences.md). 18 | We spent a lot of time discovering these insights and now share them with all of you. We hope they help! 19 | 20 | 21 | 22 | ## Training script 23 | 24 | ```bash 25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file ./deepspeed.yaml train_svd.py \ 26 | --pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 27 | --output_dir=$PATH_TO_THE_SAVE_DIR \ 28 | --dataset_type="ubc" \ 29 | --meta_info_path=$PATH_TO_THE_META_INFO_FILE_FOR_DATASET \ 30 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_VALIDATION \ 31 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_VALIDATION \ 32 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_VALIDATION \ 33 | --width=576 \ 34 | --height=1024 \ 35 | --lr_warmup_steps 500 \ 36 | --sample_n_frames 21 \ 37 | --interval_frame 3 \ 38 | --learning_rate=1e-5 \ 39 | --per_gpu_batch_size=1 \ 40 | --num_train_epochs=6000 \ 41 | --mixed_precision="bf16" \ 42 | --gradient_accumulation_steps=1 \ 43 | --checkpointing_steps=2000 \ 44 | --validation_steps=500 \ 45 | --gradient_checkpointing \ 46 | --checkpoints_total_limit 4 47 | 48 | # For Resume 49 | --controlnet_model_name_or_path $PATH_TO_THE_CONTROLNEXT_WEIGHT 50 | --unet_model_name_or_path $PATH_TO_THE_UNET_WEIGHT 51 | ``` 52 | 53 | We set `--num_train_epochs=6000` so that training does not stop on its own; you can stop the process at any point once you believe the results are satisfactory. 54 | 55 | ## Training validation 56 | 57 | Please organize the validation data like this: 58 | ``` 59 | ├───ControlNeXt-SVD-v2-Training 60 | └─── ... 61 | ├───validation 62 | | ├───ground_truth 63 | | | ├───0.png 64 | | | ├─── ... 65 | | | └───13.png 66 | | | 67 | | └───pose 68 | | | ├───0.png 69 | | | ├─── ... 70 | | | └───13.png 71 | | | 72 | | └───reference_image.png 73 | | 74 | └─── ... 
75 | ``` 76 | 77 | And then point these arguments to the corresponding paths: 78 | ```bash 79 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_VALIDATION \ 80 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_VALIDATION \ 81 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_VALIDATION \ 82 | ``` 83 | 84 | ## DeepSpeed 85 | 86 | When using `DeepSpeed` for training, you'll notice the use of a `DeepSpeedWrapperModel` in the [training script](train_svd.py#L837). This wrapper is necessary because DeepSpeed supports only a single model for parallel training; it lets us encapsulate the different modules, including ControlNet and the UNet, in one model. 87 | 88 | To perform inference with your trained weights, follow these steps: 89 | 90 | 1. Convert the generated weights to a `.bin` file using the `zero_to_fp32.py` script that DeepSpeed places in the checkpoint directory. This creates a file named `pytorch_model.bin`. 91 | 2. Use the script [unwrap_deepspeed.py](utils/unwrap_deepspeed.py) to separate the modules into distinct state dictionaries for ControlNet and the UNet. 92 | 3. Provide the paths to these weights in your inference script. 93 | 94 | ## Meta info 95 | 96 | Please construct the training dataset and provide a list of the data entries in a `.json` file. We give an example in `meta_info_example/meta_info.json` (the data list) and `meta_info_example/meta_info/1.json` (the detailed meta information for each single video, recording the detected positions and scores): 97 | 98 | `meta_info.json` 99 | ```json 100 | [ 101 | { 102 | "video_path": "PATH_TO_THE_SOURCE_VIDEO", 103 | "guide_path": "PATH_TO_THE_CORESEPONDING_POSE_VIDEO", 104 | "meta_info": "PATH_TO_THE_JSON_FILE_RECORD_THE_DETAILED_DETECTION_RESULTS(we give an example in meta_info/1.json)" 105 | } 106 | ... 107 | ] 108 | ``` 109 | 110 | ## GPU memory 111 | 112 | Training requires substantial memory, as we use a high resolution and long frame sequences to achieve optimal performance. However, you can apply certain techniques to reduce memory consumption, at some cost in performance: 113 | 114 | > 1. Adopt bf16 and fp16 (we have already implemented this). 115 | > 2. Use DeepSpeed and distributed training across multiple machines. 116 | > 3. Reduce the resolution from `--width=576 --height=1024` to a smaller size, such as `512*768`. 117 | > 4. 
Reduce the `--sample_n_frames` 118 | 119 | 120 | ### If you find this work helpful, please consider citing: 121 | ``` 122 | @article{peng2024controlnext, 123 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation}, 124 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya}, 125 | journal={arXiv preprint arXiv:2408.06070}, 126 | year={2024} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_process_ip: null 13 | main_process_port: null 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 8 18 | use_cpu: false -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/meta_info_example/meta_info.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "video_path": "PATH_TO_THE_SOURCE_VIDEO", 4 | "guide_path": "PATH_TO_THE_CORESEPONDING_POSE_VIDEO", 5 | "meta_info": "PATH_TO_THE_JSON_FILE_RECORD_THE_DETAILED_DETECTION_RESULTS(we give an example in meta_info/1.json)" 6 | } 7 | ] -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/models/controlnext_vid_svd.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D 10 | 11 | 12 | class ControlNeXtSDVModel(ModelMixin, ConfigMixin): 13 | _supports_gradient_checkpointing = True 14 | 15 | @register_to_config 16 | def __init__( 17 | self, 18 | time_embed_dim = 256, 19 | in_channels = [128, 128], 20 | out_channels = [128, 256], 21 | groups = [4, 8] 22 | ): 23 | super().__init__() 24 | 25 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0) 26 | self.time_embedding = TimestepEmbedding(128, time_embed_dim) 27 | self.embedding = nn.Sequential( 28 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 29 | nn.GroupNorm(2, 64), 30 | nn.ReLU(), 31 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 32 | nn.GroupNorm(2, 64), 33 | nn.ReLU(), 34 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 35 | nn.GroupNorm(2, 128), 36 | nn.ReLU(), 37 | ) 38 | 39 | self.down_res = nn.ModuleList() 40 | self.down_sample = nn.ModuleList() 41 | for i in range(len(in_channels)): 42 | self.down_res.append( 43 | ResnetBlock2D( 44 | in_channels=in_channels[i], 45 | out_channels=out_channels[i], 46 | temb_channels=time_embed_dim, 47 | groups=groups[i] 48 | ), 49 | ) 50 | self.down_sample.append( 51 | Downsample2D( 52 | out_channels[i], 53 | use_conv=True, 54 | out_channels=out_channels[i], 55 | padding=1, 56 | name="op", 57 | ) 58 | ) 59 | 60 | self.mid_convs = nn.ModuleList() 
61 | self.mid_convs.append(nn.Sequential( 62 | nn.Conv2d( 63 | in_channels=out_channels[-1], 64 | out_channels=out_channels[-1], 65 | kernel_size=3, 66 | stride=1, 67 | padding=1 68 | ), 69 | nn.ReLU(), 70 | nn.GroupNorm(8, out_channels[-1]), 71 | nn.Conv2d( 72 | in_channels=out_channels[-1], 73 | out_channels=out_channels[-1], 74 | kernel_size=3, 75 | stride=1, 76 | padding=1 77 | ), 78 | nn.GroupNorm(8, out_channels[-1]), 79 | )) 80 | self.mid_convs.append( 81 | nn.Conv2d( 82 | in_channels=out_channels[-1], 83 | out_channels=320, 84 | kernel_size=1, 85 | stride=1, 86 | )) 87 | 88 | self.scale = 1. 89 | 90 | def _set_gradient_checkpointing(self, module, value=False): 91 | if hasattr(module, "gradient_checkpointing"): 92 | module.gradient_checkpointing = value 93 | 94 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 95 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 96 | """ 97 | Sets the attention processor to use [feed forward 98 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 99 | 100 | Parameters: 101 | chunk_size (`int`, *optional*): 102 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 103 | over each tensor of dim=`dim`. 104 | dim (`int`, *optional*, defaults to `0`): 105 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 106 | or dim=1 (sequence length). 107 | """ 108 | if dim not in [0, 1]: 109 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 110 | 111 | # By default chunk size is 1 112 | chunk_size = chunk_size or 1 113 | 114 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 115 | if hasattr(module, "set_chunk_feed_forward"): 116 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 117 | 118 | for child in module.children(): 119 | fn_recursive_feed_forward(child, chunk_size, dim) 120 | 121 | for module in self.children(): 122 | fn_recursive_feed_forward(module, chunk_size, dim) 123 | 124 | def forward( 125 | self, 126 | sample: torch.FloatTensor, 127 | timestep: Union[torch.Tensor, float, int], 128 | ): 129 | 130 | timesteps = timestep 131 | if not torch.is_tensor(timesteps): 132 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 133 | # This would be a good case for the `match` statement (Python 3.10+) 134 | is_mps = sample.device.type == "mps" 135 | if isinstance(timestep, float): 136 | dtype = torch.float32 if is_mps else torch.float64 137 | else: 138 | dtype = torch.int32 if is_mps else torch.int64 139 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 140 | elif len(timesteps.shape) == 0: 141 | timesteps = timesteps[None].to(sample.device) 142 | 143 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 144 | batch_size, num_frames = sample.shape[:2] 145 | timesteps = timesteps.expand(batch_size) 146 | 147 | t_emb = self.time_proj(timesteps) 148 | 149 | # `Timesteps` does not contain any weights and will always return f32 tensors 150 | # but time_embedding might actually be running in fp16. so we need to cast here. 151 | # there might be better ways to encapsulate this. 
152 | t_emb = t_emb.to(dtype=sample.dtype) 153 | 154 | emb_batch = self.time_embedding(t_emb) 155 | 156 | # Flatten the batch and frames dimensions 157 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 158 | sample = sample.flatten(0, 1) 159 | # Repeat the embeddings num_video_frames times 160 | # emb: [batch, channels] -> [batch * frames, channels] 161 | emb = emb_batch.repeat_interleave(num_frames, dim=0) 162 | 163 | sample = self.embedding(sample) 164 | 165 | for res, downsample in zip(self.down_res, self.down_sample): 166 | sample = res(sample, emb) 167 | sample = downsample(sample, emb) 168 | 169 | sample = self.mid_convs[0](sample) + sample 170 | sample = self.mid_convs[1](sample) 171 | 172 | return { 173 | 'output': sample, 174 | 'scale': self.scale, 175 | } 176 | 177 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.31.0 3 | addict==2.4.0 4 | aiofiles==23.2.1 5 | aiohttp==3.9.5 6 | aiosignal==1.3.1 7 | albumentations==1.3.1 8 | altair==5.3.0 9 | annotated-types==0.7.0 10 | antlr4-python3-runtime==4.8 11 | anyio==4.4.0 12 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work 13 | async-timeout==4.0.3 14 | attrs==23.2.0 15 | av==12.0.0 16 | azureml==0.2.7 17 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 18 | basicsr==1.4.2 19 | blessed==1.20.0 20 | boto3==1.34.123 21 | botocore==1.34.123 22 | braceexpand==0.1.7 23 | cachetools==5.3.3 24 | certifi==2024.6.2 25 | cffi==1.16.0 26 | chardet==5.2.0 27 | charset-normalizer==3.3.2 28 | chumpy==0.70 29 | clean-fid==0.1.35 30 | click==8.1.7 31 | clip-anytorch==2.6.0 32 | cmake==3.29.3 33 | coloredlogs==15.0.1 34 | colorlog==6.8.2 35 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work 36 | contourpy==1.1.1 37 | cycler==0.12.1 38 | datasets==2.19.1 39 | dctorch==0.1.2 40 | debugpy @ file:///croot/debugpy_1690905042057/work 41 | decorator==4.4.2 42 | decord==0.6.0 43 | deepspeed==0.14.5 44 | -e git+https://github.com/huggingface/diffusers.git@983dec3bf787c064ed57f2621c9b7375d443f746#egg=diffusers 45 | dill==0.3.8 46 | dnspython==2.6.1 47 | docker-pycreds==0.4.0 48 | editor==1.6.6 49 | einops==0.8.0 50 | einops-exts==0.0.4 51 | email_validator==2.1.1 52 | embreex==2.17.7.post4 53 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 54 | exceptiongroup==1.2.1 55 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work 56 | facexlib==0.3.0 57 | fairscale==0.4.13 58 | fastapi==0.111.0 59 | fastapi-cli==0.0.4 60 | ffmpy==0.3.2 61 | filelock==3.15.1 62 | filterpy==1.4.5 63 | flatbuffers==24.3.25 64 | fonttools==4.52.4 65 | frozenlist==1.4.1 66 | fsspec==2024.6.0 67 | ftfy==6.2.0 68 | future==1.0.0 69 | gitdb==4.0.11 70 | GitPython==3.1.43 71 | google-auth==2.29.0 72 | google-auth-oauthlib==1.0.0 73 | gradio==4.31.5 74 | gradio_client==0.16.4 75 | gradio_imageslider==0.0.18 76 | grpcio==1.64.0 77 | h11==0.14.0 78 | hjson==3.1.0 79 | httpcore==1.0.5 80 | httptools==0.6.1 81 | httpx==0.27.0 82 | huggingface-hub==0.23.3 83 | humanfriendly==10.0 84 | idna==3.7 85 | imageio==2.34.1 86 | imageio-ffmpeg==0.4.9 87 | importlib_metadata==7.1.0 88 | importlib_resources==6.4.0 89 | inquirer==3.3.0 90 | ipykernel @ 
file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work 91 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work 92 | jax==0.4.13 93 | jaxlib==0.4.13 94 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work 95 | Jinja2==3.1.4 96 | jmespath==1.0.1 97 | joblib==1.4.2 98 | jsonmerge==1.9.2 99 | jsonschema==4.22.0 100 | jsonschema-specifications==2023.12.1 101 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work 102 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257397447/work 103 | k-diffusion==0.1.1.post1 104 | kiwisolver==1.4.5 105 | kornia==0.7.2 106 | kornia_rs==0.1.3 107 | lazy_loader==0.4 108 | lightning-utilities==0.11.2 109 | lit==18.1.6 110 | llvmlite==0.41.1 111 | lmdb==1.4.1 112 | lxml==5.2.2 113 | manopth @ file:///home/llm/bhpeng/generation/HandRefiner/MeshGraphormer/manopth 114 | mapbox-earcut==1.0.1 115 | Markdown==3.6 116 | markdown-it-py==3.0.0 117 | MarkupSafe==2.1.5 118 | matplotlib==3.7.5 119 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work 120 | mdurl==0.1.2 121 | mediapipe==0.10.0 122 | ml-dtypes==0.2.0 123 | moviepy==1.0.3 124 | mpmath==1.3.0 125 | multidict==6.0.5 126 | multiprocess==0.70.16 127 | mypy-extensions==1.0.0 128 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work 129 | networkx==3.1 130 | ninja==1.11.1.1 131 | numba==0.58.1 132 | numpy==1.24.4 133 | nvidia-cublas-cu11==11.10.3.66 134 | nvidia-cublas-cu12==12.1.3.1 135 | nvidia-cuda-cupti-cu11==11.7.101 136 | nvidia-cuda-cupti-cu12==12.1.105 137 | nvidia-cuda-nvrtc-cu11==11.7.99 138 | nvidia-cuda-nvrtc-cu12==12.1.105 139 | nvidia-cuda-runtime-cu11==11.7.99 140 | nvidia-cuda-runtime-cu12==12.1.105 141 | nvidia-cudnn-cu11==8.5.0.96 142 | nvidia-cudnn-cu12==8.9.2.26 143 | nvidia-cufft-cu11==10.9.0.58 144 | nvidia-cufft-cu12==11.0.2.54 145 | nvidia-curand-cu11==10.2.10.91 146 | nvidia-curand-cu12==10.3.2.106 147 | nvidia-cusolver-cu11==11.4.0.1 148 | nvidia-cusolver-cu12==11.4.5.107 149 | nvidia-cusparse-cu11==11.7.4.91 150 | nvidia-cusparse-cu12==12.1.0.106 151 | nvidia-ml-py==12.560.30 152 | nvidia-nccl-cu11==2.14.3 153 | nvidia-nccl-cu12==2.20.5 154 | nvidia-nvjitlink-cu12==12.5.40 155 | nvidia-nvtx-cu11==11.7.91 156 | nvidia-nvtx-cu12==12.1.105 157 | oauthlib==3.2.2 158 | omegaconf==2.1.1 159 | onnxruntime-gpu==1.17.1 160 | open-clip-torch==2.22.0 161 | openai-clip==1.0.1 162 | opencv-contrib-python==4.10.0.82 163 | opencv-python==4.9.0.80 164 | opencv-python-headless==4.10.0.82 165 | opt-einsum==3.3.0 166 | orjson==3.10.3 167 | packaging==24.1 168 | pandas==2.0.0 169 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work 170 | peft==0.11.1 171 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work 172 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 173 | Pillow==9.5.0 174 | pkgutil_resolve_name==1.3.10 175 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work 176 | proglog==0.1.10 177 | prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work 178 | protobuf==5.27.2 179 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 180 | ptyprocess @ 
file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 181 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 182 | py-cpuinfo==9.0.0 183 | pyarrow==16.1.0 184 | pyarrow-hotfix==0.6 185 | pyasn1==0.6.0 186 | pyasn1_modules==0.4.0 187 | pycollada==0.8 188 | pycparser==2.22 189 | pydantic==2.7.1 190 | pydantic_core==2.18.2 191 | pydub==0.25.1 192 | pygifsicle==1.0.7 193 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work 194 | pyparsing==3.1.2 195 | pyre-extensions==0.0.29 196 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work 197 | python-dotenv==1.0.1 198 | python-multipart==0.0.9 199 | pytorch-lightning==1.9.3 200 | pytorch-pretrained-bert==0.6.2 201 | pytz==2024.1 202 | PyWavelets==1.4.1 203 | PyYAML==6.0.1 204 | pyzmq @ file:///croot/pyzmq_1705605076900/work 205 | qudida==0.0.4 206 | readchar==4.1.0 207 | referencing==0.35.1 208 | regex==2024.5.15 209 | requests==2.32.3 210 | requests-oauthlib==2.0.0 211 | rich==13.7.1 212 | rpds-py==0.18.1 213 | rsa==4.9 214 | Rtree==1.2.0 215 | ruff==0.4.5 216 | runs==1.2.2 217 | s3transfer==0.10.1 218 | safetensors==0.4.3 219 | scikit-image==0.21.0 220 | scikit-learn==1.3.2 221 | scipy==1.10.1 222 | semantic-version==2.10.0 223 | sentencepiece==0.2.0 224 | sentry-sdk==2.5.1 225 | setproctitle==1.3.3 226 | shapely==2.0.4 227 | shellingham==1.5.4 228 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 229 | smmap==5.0.1 230 | sniffio==1.3.1 231 | sounddevice==0.4.7 232 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 233 | starlette==0.37.2 234 | support-developer==1.0.5 235 | svg.path==6.3 236 | sympy==1.12.1 237 | tb-nightly==2.14.0a20230808 238 | tensorboard==2.14.0 239 | tensorboard-data-server==0.7.2 240 | tensorboardX==2.6.2.2 241 | threadpoolctl==3.5.0 242 | tifffile==2023.7.10 243 | timm==0.6.13 244 | tokenizers==0.19.1 245 | tomli==2.0.1 246 | tomlkit==0.12.0 247 | toolz==0.12.1 248 | torch==2.3.1 249 | torchdiffeq==0.2.4 250 | torchmetrics==1.4.0.post0 251 | torchsde==0.2.6 252 | torchvision==0.15.1 253 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827257044/work 254 | tqdm==4.66.4 255 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work 256 | trampoline==0.1.2 257 | transformers==4.41.2 258 | trimesh==3.23.5 259 | triton==2.3.1 260 | typer==0.12.3 261 | typing-inspect==0.9.0 262 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work 263 | tzdata==2024.1 264 | ujson==5.10.0 265 | urllib3==2.2.1 266 | uvicorn==0.30.0 267 | uvloop==0.19.0 268 | wandb==0.17.1 269 | watchfiles==0.22.0 270 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work 271 | webdataset==0.2.86 272 | websockets==11.0.3 273 | Werkzeug==3.0.3 274 | xformers==0.0.26.post1 275 | xmod==1.8.1 276 | xxhash==3.4.1 277 | yacs==0.1.8 278 | yapf==0.40.2 279 | yarl==1.9.4 280 | zipp==3.19.0 281 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/script.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file ./deepspeed.yaml train_svd.py \ 5 | 
--pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 6 | --output_dir= $PATH_TO_THE_SAVE_DIR \ 7 | --dataset_type="ubc" \ 8 | --meta_info_path=$PATH_TO_THE_META_INFO_FILE_FOR_DATASET \ 9 | --validation_image_folder=$PATH_TO_THE_GROUND_TRUTH_DIR_FOR_EVALUATION \ 10 | --validation_control_folder=$PATH_TO_THE_POSE_DIR_FOR_EVALUATION \ 11 | --validation_image=$PATH_TO_THE_REFERENCE_IMAGE_FILE_FOR_EVALUATION \ 12 | --width=576 \ 13 | --height=1024 \ 14 | --lr_warmup_steps 500 \ 15 | --sample_n_frames 14 \ 16 | --interval_frame 3 \ 17 | --learning_rate=1e-5 \ 18 | --per_gpu_batch_size=1 \ 19 | --num_train_epochs=6000 \ 20 | --mixed_precision="bf16" \ 21 | --gradient_accumulation_steps=1 \ 22 | --checkpointing_steps=2000 \ 23 | --validation_steps=500 \ 24 | --gradient_checkpointing \ 25 | --checkpoints_total_limit 4 26 | 27 | # For Resume 28 | --controlnet_model_name_or_path $PATH_TO_THE_CONTROLNEXT_WEIGHT 29 | --unet_model_name_or_path $PATH_TO_THE_UNET_WEIGHT 30 | 31 | 32 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random 2 | import numpy as np 3 | from einops import rearrange 4 | 5 | import torch 6 | from decord import VideoReader 7 | import cv2 8 | 9 | import torchvision.transforms as transforms 10 | from torch.utils.data.dataset import Dataset 11 | from utils.util import zero_rank_print 12 | #from torchvision.io import read_image 13 | from PIL import Image 14 | def pil_image_to_numpy(image): 15 | """Convert a PIL image to a NumPy array.""" 16 | if image.mode != 'RGB': 17 | image = image.convert('RGB') 18 | return np.array(image) 19 | 20 | def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: 21 | """Convert a NumPy image to a PyTorch tensor.""" 22 | if images.ndim == 3: 23 | images = images[..., None] 24 | images = torch.from_numpy(images.transpose(0, 3, 1, 2)) 25 | return images.float() / 255 26 | 27 | 28 | class WebVid10M(Dataset): 29 | def __init__( 30 | self, 31 | csv_path, video_folder,depth_folder,motion_folder, 32 | sample_size=256, sample_stride=4, sample_n_frames=14, 33 | ): 34 | zero_rank_print(f"loading annotations from {csv_path} ...") 35 | with open(csv_path, 'r') as csvfile: 36 | self.dataset = list(csv.DictReader(csvfile)) 37 | self.length = len(self.dataset) 38 | print(f"data scale: {self.length}") 39 | random.shuffle(self.dataset) 40 | self.video_folder = video_folder 41 | self.sample_stride = sample_stride 42 | self.sample_n_frames = sample_n_frames 43 | self.depth_folder = depth_folder 44 | self.motion_values_folder=motion_folder 45 | print("length",len(self.dataset)) 46 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 47 | print("sample size",sample_size) 48 | self.pixel_transforms = transforms.Compose([ 49 | transforms.RandomHorizontalFlip(), 50 | transforms.Resize(sample_size), 51 | transforms.CenterCrop(sample_size), 52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 53 | ]) 54 | 55 | 56 | 57 | 58 | 59 | def center_crop(self,img): 60 | h, w = img.shape[-2:] # Assuming img shape is [C, H, W] or [B, C, H, W] 61 | min_dim = min(h, w) 62 | top = (h - min_dim) // 2 63 | left = (w - min_dim) // 2 64 | return img[..., top:top+min_dim, left:left+min_dim] 65 | 66 | 67 | def get_batch(self, idx): 68 | def sort_frames(frame_name): 69 | return 
int(frame_name.split('_')[1].split('.')[0]) 70 | 71 | 72 | 73 | while True: 74 | video_dict = self.dataset[idx] 75 | videoid = video_dict['videoid'] 76 | 77 | preprocessed_dir = os.path.join(self.video_folder, videoid) 78 | depth_folder = os.path.join(self.depth_folder, videoid) 79 | motion_values_file = os.path.join(self.motion_values_folder, videoid, videoid + "_average_motion.txt") 80 | 81 | if not os.path.exists(depth_folder) or not os.path.exists(motion_values_file): 82 | idx = random.randint(0, len(self.dataset) - 1) 83 | continue 84 | 85 | # Sort and limit the number of image and depth files to 14 86 | image_files = sorted(os.listdir(preprocessed_dir), key=sort_frames)[:14] 87 | depth_files = sorted(os.listdir(depth_folder), key=sort_frames)[:14] 88 | 89 | # Check if there are enough frames for both image and depth 90 | if len(image_files) < 14 or len(depth_files) < 14: 91 | idx = random.randint(0, len(self.dataset) - 1) 92 | continue 93 | 94 | # Load image frames 95 | numpy_images = np.array([pil_image_to_numpy(Image.open(os.path.join(preprocessed_dir, img))) for img in image_files]) 96 | pixel_values = numpy_to_pt(numpy_images) 97 | 98 | # Load depth frames 99 | numpy_depth_images = np.array([pil_image_to_numpy(Image.open(os.path.join(depth_folder, df))) for df in depth_files]) 100 | depth_pixel_values = numpy_to_pt(numpy_depth_images) 101 | 102 | # Load motion values 103 | with open(motion_values_file, 'r') as file: 104 | motion_values = float(file.read().strip()) 105 | 106 | return pixel_values, depth_pixel_values, motion_values 107 | 108 | 109 | 110 | 111 | def __len__(self): 112 | return self.length 113 | 114 | def __getitem__(self, idx): 115 | 116 | #while True: 117 | # try: 118 | pixel_values, depth_pixel_values,motion_values = self.get_batch(idx) 119 | # break 120 | # except Exception as e: 121 | # print(e) 122 | # idx = random.randint(0, self.length - 1) 123 | 124 | pixel_values = self.pixel_transforms(pixel_values) 125 | sample = dict(pixel_values=pixel_values, depth_pixel_values=depth_pixel_values,motion_values=motion_values) 126 | return sample 127 | 128 | 129 | 130 | 131 | if __name__ == "__main__": 132 | from utils.util import save_videos_grid 133 | 134 | dataset = WebVid10M( 135 | csv_path="/data/webvid/results_2M_train.csv", 136 | video_folder="/data/webvid/data/videos", 137 | sample_size=256, 138 | sample_stride=4, sample_n_frames=16, 139 | is_image=True, 140 | ) 141 | import pdb 142 | pdb.set_trace() 143 | 144 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,) 145 | for idx, batch in enumerate(dataloader): 146 | print(batch["pixel_values"].shape, len(batch["text"])) 147 | # for i in range(batch["pixel_values"].shape[0]): 148 | # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/utils/extract_learned_paras.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import torch 5 | import os 6 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNetModel 7 | from safetensors.torch import save_file, load_file 8 | 9 | 10 | """ 11 | python -m utils.extract_learned_paras \ 12 | /home/llm/bhpeng/generation/svd-temporal-controlnet/outputs_sdxt_mid_upblocks/ft_on_duqi_after_hands500_checkpoint-200/unet/unet_fp16.bin \ 13 | 
/home/llm/bhpeng/generation/svd-temporal-controlnet/outputs_sdxt_mid_upblocks/ft_on_duqi_after_hands500_checkpoint-200/unet/unet_fp16_increase.bin \ 14 | --pretrained_path /home/llm/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid-xt-1-1/snapshots/a423ba0d3e1a94a57ebc68e98691c43104198394 15 | """ 16 | 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("src_path", 22 | type=str, 23 | help="path to the video") 24 | parser.add_argument("dst_path", 25 | type=str, 26 | help="path to the save_dict") 27 | parser.add_argument("--pretrained_path", 28 | type=str, 29 | default="/home/llm/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid/snapshots/ae8391f7321be9ff8941508123715417da827aa4") 30 | parser.add_argument("--save_as_fp32", 31 | action="store_true",) 32 | parser.add_argument("--save_weight_increase", 33 | action="store_true",) 34 | args = parser.parse_args() 35 | 36 | unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained( 37 | args.pretrained_path, 38 | subfolder="unet", 39 | low_cpu_mem_usage=True, 40 | variant="fp16", 41 | ) 42 | pretrained_state_dict = unet.state_dict() 43 | 44 | if os.path.splitext(args.src_path)[1] == ".bin": 45 | src_state_dict = torch.load(args.src_path) 46 | elif os.path.splitext(args.src_path)[1] == ".safetensors": 47 | src_state_dict = load_file(args.src_path) 48 | 49 | for k in list(src_state_dict.keys()): 50 | src_state_dict[k] = src_state_dict[k].to(pretrained_state_dict[k]) 51 | if torch.allclose(src_state_dict[k], pretrained_state_dict[k]): 52 | src_state_dict.pop(k) 53 | continue 54 | if args.save_weight_increase: 55 | src_state_dict[k] = src_state_dict[k] - pretrained_state_dict[k] 56 | if not args.save_as_fp32: 57 | src_state_dict[k] = src_state_dict[k].half() 58 | 59 | 60 | if os.path.splitext(args.dst_path)[1] == ".bin": 61 | torch.save(src_state_dict, args.dst_path) 62 | elif os.path.splitext(args.dst_path)[1] == ".safetensors": 63 | save_file(src_state_dict, args.dst_path) 64 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/utils/extract_vid2img.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import random 5 | import os 6 | from decord import VideoReader 7 | import cv2 8 | 9 | 10 | """ 11 | python -m utils.extract_vid2img /home/llm/bhpeng/generation/svd-temporal-controlnet/proj/dataset/cropped/v2/users/肚脐小师妹/7296828856386784548.mp4 /home/llm/bhpeng/generation/svd-temporal-controlnet/validation_demo/test/0 12 | """ 13 | 14 | 15 | if __name__ == "__main__": 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("video_path", 19 | type=str, 20 | help="path to the video") 21 | parser.add_argument("save_dict", 22 | type=str, 23 | help="path to the save_dict") 24 | parser.add_argument("--interval_frame", 25 | type=int, 26 | default=2) 27 | parser.add_argument("--sample_n_frames", 28 | type=int, 29 | default=1) 30 | args = parser.parse_args() 31 | 32 | 33 | video_path = args.video_path 34 | pose_path = video_path.replace("/users", "/pose") 35 | 36 | save_video_path = os.path.join(args.save_dict, "rgb") 37 | save_pose_path = os.path.join(args.save_dict, "pose") 38 | 39 | if not os.path.exists(save_video_path): 40 | os.makedirs(save_video_path) 41 | if not os.path.exists(save_pose_path): 42 | os.makedirs(save_pose_path) 43 | 44 | vr = VideoReader(video_path) 45 | length = len(vr) 46 | segment_length = 
args.interval_frame * args.sample_n_frames 47 | assert length >= segment_length, "Too short video..." 48 | bg_frame_id = random.randint(0, length - segment_length) 49 | frame_ids = list(range(bg_frame_id, bg_frame_id + segment_length, args.interval_frame)) 50 | for idx, fid in enumerate(frame_ids): 51 | frame = vr[fid].asnumpy() 52 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 53 | frame = cv2.imwrite(os.path.join(save_video_path, "{}.png".format(idx)), frame) 54 | 55 | 56 | vr = VideoReader(pose_path) 57 | for idx, fid in enumerate(frame_ids): 58 | frame = vr[fid].asnumpy() 59 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 60 | frame = cv2.imwrite(os.path.join(save_pose_path, "{}.png".format(idx)), frame) 61 | 62 | 63 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/utils/unwrap_deepspeed.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import torch 5 | import os 6 | from collections import OrderedDict 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("checkpoint_dir", 13 | type=str, 14 | help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 15 | args = parser.parse_args() 16 | state_dict = torch.load(os.path.join(args.checkpoint_dir, "pytorch_model.bin")) 17 | 18 | unet_state = OrderedDict() 19 | controlnext_state = OrderedDict() 20 | for name, data in state_dict.items(): 21 | model = name.split('.', 1)[0] 22 | module = name.split('.', 1)[1] 23 | if model == 'unet': 24 | unet_state[module] = data 25 | elif model == 'controlnext': 26 | controlnext_state[module] = data 27 | 28 | for model in ['unet', 'controlnext']: 29 | if not os.path.exists(os.path.join(args.checkpoint_dir, model)): 30 | os.makedirs(os.path.join(args.checkpoint_dir, model)) 31 | 32 | 33 | torch.save(unet_state, os.path.join(args.checkpoint_dir, "unet", "diffusion_pytorch_model.bin")) 34 | torch.save(controlnext_state, os.path.join(args.checkpoint_dir, "controlnext", "diffusion_pytorch_model.bin")) 35 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2-Training/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | 15 | def zero_rank_print(s): 16 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 17 | 18 | 19 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 20 | videos = rearrange(videos, "b c t h w -> t b c h w") 21 | outputs = [] 22 | for x in videos: 23 | x = torchvision.utils.make_grid(x, nrow=n_rows) 24 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 25 | if rescale: 26 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 27 | x = (x * 255).numpy().astype(np.uint8) 28 | outputs.append(x) 29 | 30 | os.makedirs(os.path.dirname(path), exist_ok=True) 31 | imageio.mimsave(path, outputs, fps=fps) 32 | 33 | 34 | # DDIM Inversion 35 | @torch.no_grad() 36 | def init_prompt(prompt, pipeline): 37 | uncond_input = pipeline.tokenizer( 38 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 39 | return_tensors="pt" 40 | ) 41 | uncond_embeddings = 
pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 42 | text_input = pipeline.tokenizer( 43 | [prompt], 44 | padding="max_length", 45 | max_length=pipeline.tokenizer.model_max_length, 46 | truncation=True, 47 | return_tensors="pt", 48 | ) 49 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 50 | context = torch.cat([uncond_embeddings, text_embeddings]) 51 | 52 | return context 53 | 54 | 55 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 56 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 57 | timestep, next_timestep = min( 58 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 59 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 60 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 61 | beta_prod_t = 1 - alpha_prod_t 62 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 63 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 64 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 65 | return next_sample 66 | 67 | 68 | def get_noise_pred_single(latents, t, context, unet): 69 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 70 | return noise_pred 71 | 72 | 73 | @torch.no_grad() 74 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 75 | context = init_prompt(prompt, pipeline) 76 | uncond_embeddings, cond_embeddings = context.chunk(2) 77 | all_latent = [latent] 78 | latent = latent.clone().detach() 79 | for i in tqdm(range(num_inv_steps)): 80 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 81 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 82 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 83 | all_latent.append(latent) 84 | return all_latent 85 | 86 | 87 | @torch.no_grad() 88 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 89 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 90 | return ddim_latents 91 | 92 | 93 | # def load_weights( 94 | # animation_pipeline, 95 | # # motion module 96 | # motion_module_path = "", 97 | # motion_module_lora_configs = [], 98 | # # image layers 99 | # dreambooth_model_path = "", 100 | # lora_model_path = "", 101 | # lora_alpha = 0.8, 102 | # ): 103 | # # 1.1 motion module 104 | # unet_state_dict = {} 105 | # if motion_module_path != "": 106 | # print(f"load motion module from {motion_module_path}") 107 | # motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 108 | # motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 109 | # unet_state_dict.update({name.replace("module.", ""): param for name, param in motion_module_state_dict.items()}) 110 | 111 | # missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 112 | # assert len(unexpected) == 0 113 | # del unet_state_dict 114 | 115 | # # if dreambooth_model_path != "": 116 | # # print(f"load dreambooth model from {dreambooth_model_path}") 117 | # # if dreambooth_model_path.endswith(".safetensors"): 118 | # # dreambooth_state_dict = {} 119 | # # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 120 | # # for key in f.keys(): 
121 | # # dreambooth_state_dict[key.replace("module.", "")] = f.get_tensor(key) 122 | # # elif dreambooth_model_path.endswith(".ckpt"): 123 | # # dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 124 | # # dreambooth_state_dict = {k.replace("module.", ""): v for k, v in dreambooth_state_dict.items()} 125 | 126 | # # 1. vae 127 | # # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 128 | # # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 129 | # # 2. unet 130 | # # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 131 | # # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 132 | # # 3. text_model 133 | # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 134 | # # del dreambooth_state_dict 135 | 136 | # if lora_model_path != "": 137 | # print(f"load lora model from {lora_model_path}") 138 | # assert lora_model_path.endswith(".safetensors") 139 | # lora_state_dict = {} 140 | # with safe_open(lora_model_path, framework="pt", device="cpu") as f: 141 | # for key in f.keys(): 142 | # lora_state_dict[key.replace("module.", "")] = f.get_tensor(key) 143 | 144 | # animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 145 | # del lora_state_dict 146 | 147 | # for motion_module_lora_config in motion_module_lora_configs: 148 | # path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 149 | # print(f"load motion LoRA from {path}") 150 | 151 | # motion_lora_state_dict = torch.load(path, map_location="cpu") 152 | # motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 153 | # motion_lora_state_dict = {k.replace("module.", ""): v for k, v in motion_lora_state_dict.items()} 154 | 155 | # animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha) 156 | 157 | # return animation_pipeline 158 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SVD-v2 2 | 3 | This is our implementation of ControlNeXt based on [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). It can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) with a more concise and efficient architecture. 4 | 5 | Compared to image generation, video generation poses significantly greater challenges. While direct training of the generation model using our method is feasible, we also employ various engineering strategies to enhance performance. Although they are irrespective of academic algorithms. 6 | 7 | 8 | > Please refer to [Examples](#examples) for further intuitive details.\ 9 | > Please refer to [Base model](#base-model) for more details of our used base model. \ 10 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\ 11 | > Please refer to [Advanced Performance](#advanced-performance) for more details to achieve a better performance.\ 12 | > Please refer to [Limitations](#limitations) for more details about the limitations of current work. 
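To make the claim of a concise architecture concrete, the control branch used here is the small convolutional encoder defined in `models/controlnext_vid_svd.py`. The snippet below is only a minimal sketch for inspection, assuming it is run from the `ControlNeXt-SVD-v2` directory with the pinned dependencies installed; the batch size, frame count, and spatial resolution are arbitrary placeholders.

```python
# Minimal sketch: instantiate the ControlNeXt control branch and inspect its output.
import torch
from models.controlnext_vid_svd import ControlNeXtSDVModel

controlnext = ControlNeXtSDVModel()                 # a small stack of conv / resnet blocks
pose = torch.randn(1, 14, 3, 256, 256)              # [batch, frames, channels, height, width]
out = controlnext(pose, timestep=torch.tensor([0.5]))

print(sum(p.numel() for p in controlnext.parameters()))  # parameter count of the control branch
print(out["output"].shape)                               # per-frame 320-channel conditioning features
print(out["scale"])                                      # scale factor returned alongside the features
```

The full inference entry point, `run_controlnext.py`, wires this module together with the modified UNet and the ControlNeXt pipeline.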
13 | 14 | # Examples 15 | If you can't load the videos, you can also directly download them from [here](examples/demos) and [here](examples/video). 16 | Or you can view them from our [Project Page](https://pbihao.github.io/projects/controlnext/index.html) or [BiliBili](https://www.bilibili.com/video/BV1wJYbebEE7/?buvid=YC4E03C93B119ADD4080B0958DE73F9DDCAC&from_spmid=dt.dt.video.0&is_story_h5=false&mid=y82Gz7uArS6jTQ6zuqJj3w%3D%3D&p=1&plat_id=114&share_from=ugc&share_medium=iphone&share_plat=ios&share_session_id=4E5549FC-0710-4030-BD2C-CDED80B46D08&share_source=WEIXIN&share_source=weixin&share_tag=s_i×tamp=1723123770&unique_k=XLZLhCq&up_id=176095810&vd_source=3791450598e16da25ecc2477fc7983db). 17 | 18 |
43 | # Base Model 44 | 45 | For the v2 version, we adopt the following operations to improve performance: 46 | * We have collected a higher-quality dataset with higher resolution to train our model. 47 | * We have extended the training and inference batch frames to 24. 48 | * We have extended the video height and width to a resolution of 576 × 1024. 49 | * We conduct extensive continual training of SVD on human-related videos to enhance its ability to generate human-related content. 50 | * We adopt fp32. 51 | * We adopt pose alignment during inference, following related works. 52 | 53 | # Inference 54 | 55 | 1. Clone our repository 56 | 2. `cd ControlNeXt-SVD-v2` 57 | 3. Download the pretrained weights into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/v2). (For more details, please refer to [Base Model](#base-model).) 58 | 4. Download the DWPose weights, including [dw-ll_ucoco_384](https://drive.google.com/file/d/12L8E2oAgZy4VACGSK9RaZBZrfgx7VTA2/view?usp=sharing) and [yolox_l](https://drive.google.com/file/d/1w9pXC8tT0p9ndMN-CArp1__b2GbzewWI/view?usp=sharing), into `pretrained/DWPose`. For more details, please refer to [DWPose](https://github.com/IDEA-Research/DWPose): 59 | ``` 60 | ├───pretrained 61 | └───DWPose 62 | | │───dw-ll_ucoco_384.onnx 63 | | └───yolox_l.onnx 64 | | 65 | ├───unet.bin 66 | └───controlnet.bin 67 | ``` 68 | 5. Run the script 69 | 70 | ```bash 71 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 72 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 73 | --output_dir outputs \ 74 | --max_frame_num 240 \ 75 | --guidance_scale 3 \ 76 | --batch_frames 24 \ 77 | --sample_stride 2 \ 78 | --overlap 6 \ 79 | --height 1024 \ 80 | --width 576 \ 81 | --controlnext_path pretrained/controlnet.bin \ 82 | --unet_path pretrained/unet.bin \ 83 | --validation_control_video_path examples/video/02.mp4 \ 84 | --ref_image_path examples/ref_imgs/01.jpeg 85 | ``` 86 | 87 | > --pretrained_model_name_or_path : the pretrained base model; we pretrain and fine-tune our models based on [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1)\ 88 | > --controlnext_path : the model path of the ControlNeXt module (a lightweight module) \ 89 | > --unet_path : the model path of the UNet \ 90 | > --ref_image_path: the path to the reference image \ 91 | > --overlap: the number of overlapped frames for long-video generation. \ 92 | > --sample_stride: the sampling stride for the conditional controls. You can set it to `1` for smoother generation, although this requires more computation. 93 | 94 | 6. Face Enhancement (optional, recommended for bad faces) 95 | 96 | > Currently, the model is not specifically trained for IP consistency, as there are already many mature tools available. Additionally, alternatives like Animate Anyone also adopt such post-processing techniques. 97 | 98 | a. Clone [FaceFusion](https://github.com/facefusion/facefusion): \ 99 | ```git clone https://github.com/facefusion/facefusion``` 100 | 101 | b. Enter the directory:\ 102 | ```cd facefusion``` 103 | 104 | c. Install FaceFusion (we recommend creating a new conda environment to avoid conflicts):\ 105 | ```python install.py``` 106 | 107 | d.
Run the command: 108 | ``` 109 | python run.py \ 110 | -s ../outputs/demo.jpg \ 111 | -t ../outputs/demo.mp4 \ 112 | -o ../outputs/out.mp4 \ 113 | --headless \ 114 | --execution-providers cuda \ 115 | --face-selector-mode one 116 | ``` 117 | 118 | > -s: the reference image \ 119 | > -t: the path to the original video\ 120 | > -o: the path to store the refined video\ 121 | > --headless: no GUI needed\ 122 | > --execution-providers cuda: use CUDA for acceleration (if available; otherwise the CPU is usually sufficient) 123 | 124 | # Advanced Performance 125 | In this section, we delve into additional details and practical experience for enhancing video generation. These factors are independent of the academic algorithm itself, yet crucial for achieving superior results. Many closely related works incorporate these strategies. 126 | 127 | ### Reference Image 128 | 129 | It is crucial to ensure that the reference image is clear and easily understandable, and especially that the face in the reference is aligned with the pose. 130 | 131 | 132 | ### Face Enhancement 133 | 134 | Most related works utilize face enhancement as part of the post-processing. This is especially relevant when generating videos based on images of unfamiliar individuals, such as friends, who were not included in the base model's pretraining and are therefore unseen, out-of-distribution data. 135 | 136 | We recommend [FaceFusion](https://github.com/facefusion/facefusion) for this post-processing, and please let us know if you have a better solution. 137 | 138 | 139 | Please refer to [FaceFusion](https://github.com/facefusion/facefusion) for more details. 140 | 141 | 142 | ![Facefusion](examples/facefusion/facefusion.jpg) 143 | 144 | 145 | ### Continuous Fine-tuning 146 | 147 | To significantly enhance performance on a specific pose sequence, you can continuously fine-tune the model for just a few hundred steps. 148 | 149 | We will release the related fine-tuning code later. 150 | 151 | ### Pose Generation 152 | 153 | We adopt [DWPose](https://github.com/IDEA-Research/DWPose) for pose generation, and follow related works ([1](https://humanaigc.github.io/animate-anyone/), [2](https://tencent.github.io/MimicMotion/)) to align the pose with the reference image. 154 | 155 | # Limitations 156 | 157 | ## IP Consistency 158 | 159 | We did not prioritize maintaining IP consistency during the development of the generation model and currently rely on a helper model for face enhancement. 160 | 161 | However, additional training can be implemented to ensure IP consistency moving forward. 162 | 163 | This also leaves a possible direction for further improvement. 164 | 165 | ## Base model 166 | 167 | The base model plays a crucial role in generating human features, particularly hands and faces. We encourage collaboration to improve the base model for enhanced human-related video generation.
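As mentioned in the Pose Generation section above, the conditioning frames are produced by DWPose and rescaled to the reference image before they reach the pipeline. The snippet below is a minimal sketch of running that preprocessing step on its own with the helpers bundled in this folder (`utils/pre_process.py` and `dwpose/`); it assumes you run it from the `ControlNeXt-SVD-v2` directory with the DWPose weights already placed under `pretrained/DWPose`, and it uses the bundled example inputs with an arbitrary output directory.

```python
# Minimal sketch: produce aligned pose frames with the bundled preprocessing helpers.
import os
from utils.pre_process import preprocess  # wraps dwpose get_image_pose / get_video_pose

# Returns (list of PIL pose frames, reference PIL image). The first pose frame is the
# pose of the reference image itself; the remaining frames are the driving-video poses
# rescaled to the body proportions of the reference.
pose_frames, ref_image = preprocess(
    "examples/video/02.mp4",        # driving pose video
    "examples/ref_imgs/01.jpeg",    # reference image the poses are rescaled to
    width=576,
    height=1024,
    sample_stride=2,                # same meaning as --sample_stride above
    max_frame_num=240,              # same meaning as --max_frame_num above
)

# Save the visualized pose maps for inspection (output directory is arbitrary).
out_dir = "outputs/pose_vis"
os.makedirs(out_dir, exist_ok=True)
ref_image.save(os.path.join(out_dir, "ref.png"))
for i, frame in enumerate(pose_frames):
    frame.save(os.path.join(out_dir, f"{i:04d}.png"))
```

This is the same preprocessing that `run_controlnext.py` performs before calling the pipeline, so the saved frames correspond to the pose half of the side-by-side video it writes.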
168 | 169 | # TODO 170 | 171 | * Training and finetune code 172 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/dwpose/__init__.py -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/dwpose_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .wholebody import Wholebody 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class DWposeDetector: 12 | """ 13 | A pose detect method for image-like data. 14 | 15 | Parameters: 16 | model_det: (str) serialized ONNX format model path, 17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx 18 | model_pose: (str) serialized ONNX format model path, 19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx 20 | device: (str) 'cpu' or 'cuda:{device_id}' 21 | """ 22 | def __init__(self, model_det, model_pose, device='cpu'): 23 | self.args = model_det, model_pose, device 24 | 25 | def release_memory(self): 26 | if hasattr(self, 'pose_estimation'): 27 | del self.pose_estimation 28 | import gc; gc.collect() 29 | 30 | def __call__(self, oriImg): 31 | if not hasattr(self, 'pose_estimation'): 32 | self.pose_estimation = Wholebody(*self.args) 33 | 34 | oriImg = oriImg.copy() 35 | H, W, C = oriImg.shape 36 | with torch.no_grad(): 37 | candidate, score = self.pose_estimation(oriImg) 38 | nums, _, locs = candidate.shape 39 | candidate[..., 0] /= float(W) 40 | candidate[..., 1] /= float(H) 41 | body = candidate[:, :18].copy() 42 | body = body.reshape(nums * 18, locs) 43 | subset = score[:, :18].copy() 44 | for i in range(len(subset)): 45 | for j in range(len(subset[i])): 46 | if subset[i][j] > 0.3: 47 | subset[i][j] = int(18 * i + j) 48 | else: 49 | subset[i][j] = -1 50 | 51 | # un_visible = subset < 0.3 52 | # candidate[un_visible] = -1 53 | 54 | # foot = candidate[:, 18:24] 55 | 56 | faces = candidate[:, 24:92] 57 | 58 | hands = candidate[:, 92:113] 59 | hands = np.vstack([hands, candidate[:, 113:]]) 60 | 61 | faces_score = score[:, 24:92] 62 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]]) 63 | 64 | bodies = dict(candidate=body, subset=subset, score=score[:, :18]) 65 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score) 66 | 67 | return pose 68 | 69 | dwpose_detector = DWposeDetector( 70 | model_det="pretrained/DWPose/yolox_l.onnx", 71 | model_pose="pretrained/DWPose/dw-ll_ucoco_384.onnx", 72 | device=device) 73 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def nms(boxes, scores, nms_thr): 6 | """Single class NMS implemented in Numpy. 
7 | 8 | Args: 9 | boxes (np.ndarray): shape=(N,4); N is number of boxes 10 | scores (np.ndarray): the score of bboxes 11 | nms_thr (float): the threshold in NMS 12 | 13 | Returns: 14 | List[int]: output bbox ids 15 | """ 16 | x1 = boxes[:, 0] 17 | y1 = boxes[:, 1] 18 | x2 = boxes[:, 2] 19 | y2 = boxes[:, 3] 20 | 21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 22 | order = scores.argsort()[::-1] 23 | 24 | keep = [] 25 | while order.size > 0: 26 | i = order[0] 27 | keep.append(i) 28 | xx1 = np.maximum(x1[i], x1[order[1:]]) 29 | yy1 = np.maximum(y1[i], y1[order[1:]]) 30 | xx2 = np.minimum(x2[i], x2[order[1:]]) 31 | yy2 = np.minimum(y2[i], y2[order[1:]]) 32 | 33 | w = np.maximum(0.0, xx2 - xx1 + 1) 34 | h = np.maximum(0.0, yy2 - yy1 + 1) 35 | inter = w * h 36 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 37 | 38 | inds = np.where(ovr <= nms_thr)[0] 39 | order = order[inds + 1] 40 | 41 | return keep 42 | 43 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 44 | """Multiclass NMS implemented in Numpy. Class-aware version. 45 | 46 | Args: 47 | boxes (np.ndarray): shape=(N,4); N is number of boxes 48 | scores (np.ndarray): the score of bboxes 49 | nms_thr (float): the threshold in NMS 50 | score_thr (float): the threshold of cls score 51 | 52 | Returns: 53 | np.ndarray: outputs bboxes coordinate 54 | """ 55 | final_dets = [] 56 | num_classes = scores.shape[1] 57 | for cls_ind in range(num_classes): 58 | cls_scores = scores[:, cls_ind] 59 | valid_score_mask = cls_scores > score_thr 60 | if valid_score_mask.sum() == 0: 61 | continue 62 | else: 63 | valid_scores = cls_scores[valid_score_mask] 64 | valid_boxes = boxes[valid_score_mask] 65 | keep = nms(valid_boxes, valid_scores, nms_thr) 66 | if len(keep) > 0: 67 | cls_inds = np.ones((len(keep), 1)) * cls_ind 68 | dets = np.concatenate( 69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 70 | ) 71 | final_dets.append(dets) 72 | if len(final_dets) == 0: 73 | return None 74 | return np.concatenate(final_dets, 0) 75 | 76 | def demo_postprocess(outputs, img_size, p6=False): 77 | grids = [] 78 | expanded_strides = [] 79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 80 | 81 | hsizes = [img_size[0] // stride for stride in strides] 82 | wsizes = [img_size[1] // stride for stride in strides] 83 | 84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 87 | grids.append(grid) 88 | shape = grid.shape[:2] 89 | expanded_strides.append(np.full((*shape, 1), stride)) 90 | 91 | grids = np.concatenate(grids, 1) 92 | expanded_strides = np.concatenate(expanded_strides, 1) 93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 95 | 96 | return outputs 97 | 98 | def preprocess(img, input_size, swap=(2, 0, 1)): 99 | if len(img.shape) == 3: 100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 101 | else: 102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 103 | 104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 105 | resized_img = cv2.resize( 106 | img, 107 | (int(img.shape[1] * r), int(img.shape[0] * r)), 108 | interpolation=cv2.INTER_LINEAR, 109 | ).astype(np.uint8) 110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 111 | 112 | padded_img = padded_img.transpose(swap) 113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 114 | return 
padded_img, r 115 | 116 | def inference_detector(session, oriImg): 117 | """run human detect 118 | """ 119 | input_shape = (640,640) 120 | img, ratio = preprocess(oriImg, input_shape) 121 | 122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 123 | output = session.run(None, ort_inputs) 124 | predictions = demo_postprocess(output[0], input_shape)[0] 125 | 126 | boxes = predictions[:, :4] 127 | scores = predictions[:, 4:5] * predictions[:, 5:] 128 | 129 | boxes_xyxy = np.ones_like(boxes) 130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 134 | boxes_xyxy /= ratio 135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 136 | if dets is not None: 137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 138 | isscore = final_scores>0.3 139 | iscat = final_cls_inds == 0 140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 141 | final_boxes = final_boxes[isbbox] 142 | else: 143 | final_boxes = np.array([]) 144 | 145 | return final_boxes 146 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/preprocess.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import decord 3 | import numpy as np 4 | 5 | from .util import draw_pose 6 | from .dwpose_detector import dwpose_detector as dwprocessor 7 | 8 | 9 | def get_video_pose( 10 | video_path: str, 11 | ref_image: np.ndarray, 12 | max_frame_num = None, 13 | sample_stride: int=1): 14 | """preprocess ref image pose and video pose 15 | 16 | Args: 17 | video_path (str): video pose path 18 | ref_image (np.ndarray): reference image 19 | sample_stride (int, optional): Defaults to 1. 
20 | 21 | Returns: 22 | np.ndarray: sequence of video pose 23 | """ 24 | # select ref-keypoint from reference pose for pose rescale 25 | ref_pose = dwprocessor(ref_image) 26 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 27 | ref_keypoint_id = [i for i in ref_keypoint_id \ 28 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0] 29 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id] 30 | 31 | height, width, _ = ref_image.shape 32 | 33 | # read input video 34 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0)) 35 | # sample_stride *= max(1, int(vr.get_avg_fps() / 24)) 36 | if max_frame_num is None: 37 | max_frame_num = len(vr) 38 | else: 39 | max_frame_num = min(len(vr), max_frame_num * sample_stride) 40 | 41 | frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy() 42 | detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")] 43 | dwprocessor.release_memory() 44 | 45 | detected_bodies = np.stack( 46 | [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:, 47 | ref_keypoint_id] 48 | # compute linear-rescale params 49 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1) 50 | fh, fw, _ = vr[0].shape 51 | ax = ay / (fh / fw / height * width) 52 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax) 53 | a = np.array([ax, ay]) 54 | b = np.array([bx, by]) 55 | output_pose = [] 56 | # pose rescale 57 | for detected_pose in detected_poses: 58 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b 59 | detected_pose['faces'] = detected_pose['faces'] * a + b 60 | detected_pose['hands'] = detected_pose['hands'] * a + b 61 | im = draw_pose(detected_pose, height, width) 62 | output_pose.append(np.array(im)) 63 | return np.stack(output_pose) 64 | 65 | 66 | def get_image_pose(ref_image): 67 | """process image pose 68 | 69 | Args: 70 | ref_image (np.ndarray): reference image pixel value 71 | 72 | Returns: 73 | np.ndarray: pose visual image in RGB-mode 74 | """ 75 | height, width, _ = ref_image.shape 76 | ref_pose = dwprocessor(ref_image) 77 | pose_img = draw_pose(ref_pose, height, width) 78 | return np.array(pose_img) 79 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | def alpha_blend_color(color, alpha): 10 | """blend color according to point conf 11 | """ 12 | return [int(c * alpha) for c in color] 13 | 14 | def draw_bodypose(canvas, candidate, subset, score): 15 | H, W, C = canvas.shape 16 | candidate = np.array(candidate) 17 | subset = np.array(subset) 18 | 19 | stickwidth = 4 20 | 21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 23 | [1, 16], [16, 18], [3, 17], [6, 18]] 24 | 25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 28 | 29 | for i in range(17): 30 | for n in range(len(subset)): 31 | index = subset[n][np.array(limbSeq[i]) 
- 1] 32 | conf = score[n][np.array(limbSeq[i]) - 1] 33 | if conf[0] < 0.3 or conf[1] < 0.3: 34 | continue 35 | Y = candidate[index.astype(int), 0] * float(W) 36 | X = candidate[index.astype(int), 1] * float(H) 37 | mX = np.mean(X) 38 | mY = np.mean(Y) 39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1])) 43 | 44 | canvas = (canvas * 0.6).astype(np.uint8) 45 | 46 | for i in range(18): 47 | for n in range(len(subset)): 48 | index = int(subset[n][i]) 49 | if index == -1: 50 | continue 51 | x, y = candidate[index][0:2] 52 | conf = score[n][i] 53 | x = int(x * W) 54 | y = int(y * H) 55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1) 56 | 57 | return canvas 58 | 59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores): 60 | H, W, C = canvas.shape 61 | 62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 64 | 65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores): 66 | 67 | for ie, e in enumerate(edges): 68 | x1, y1 = peaks[e[0]] 69 | x2, y2 = peaks[e[1]] 70 | x1 = int(x1 * W) 71 | y1 = int(y1 * H) 72 | x2 = int(x2 * W) 73 | y2 = int(y2 * H) 74 | score = int(scores[e[0]] * scores[e[1]] * 255) 75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 76 | cv2.line(canvas, (x1, y1), (x2, y2), 77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2) 78 | 79 | for i, keyponit in enumerate(peaks): 80 | x, y = keyponit 81 | x = int(x * W) 82 | y = int(y * H) 83 | score = int(scores[i] * 255) 84 | if x > eps and y > eps: 85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1) 86 | return canvas 87 | 88 | def draw_facepose(canvas, all_lmks, all_scores): 89 | H, W, C = canvas.shape 90 | for lmks, scores in zip(all_lmks, all_scores): 91 | for lmk, score in zip(lmks, scores): 92 | x, y = lmk 93 | x = int(x * W) 94 | y = int(y * H) 95 | conf = int(score * 255) 96 | if x > eps and y > eps: 97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1) 98 | return canvas 99 | 100 | def draw_pose(pose, H, W, ref_w=2160): 101 | """vis dwpose outputs 102 | 103 | Args: 104 | pose (List): DWposeDetector outputs in dwpose_detector.py 105 | H (int): height 106 | W (int): width 107 | ref_w (int, optional) Defaults to 2160. 
108 | 109 | Returns: 110 | np.ndarray: image pixel value in RGB mode 111 | """ 112 | bodies = pose['bodies'] 113 | faces = pose['faces'] 114 | hands = pose['hands'] 115 | candidate = bodies['candidate'] 116 | subset = bodies['subset'] 117 | 118 | sz = min(H, W) 119 | sr = (ref_w / sz) if sz != ref_w else 1 120 | 121 | ########################################## create zero canvas ################################################## 122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8) 123 | 124 | ########################################### draw body pose ##################################################### 125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score']) 126 | 127 | ########################################### draw hand pose ##################################################### 128 | canvas = draw_handpose(canvas, hands, pose['hands_score']) 129 | 130 | ########################################### draw face pose ##################################################### 131 | canvas = draw_facepose(canvas, faces, pose['faces_score']) 132 | 133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1) 134 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnxruntime as ort 3 | 4 | from .onnxdet import inference_detector 5 | from .onnxpose import inference_pose 6 | 7 | 8 | class Wholebody: 9 | """detect human pose by dwpose 10 | """ 11 | def __init__(self, model_det, model_pose, device="cpu"): 12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] 13 | provider_options = None if device == 'cpu' else [{'device_id': 0}] 14 | 15 | self.session_det = ort.InferenceSession( 16 | path_or_bytes=model_det, providers=providers, provider_options=provider_options 17 | ) 18 | self.session_pose = ort.InferenceSession( 19 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options 20 | ) 21 | 22 | def __call__(self, oriImg): 23 | """call to process dwpose-detect 24 | 25 | Args: 26 | oriImg (np.ndarray): detected image 27 | 28 | """ 29 | det_result = inference_detector(self.session_det, oriImg) 30 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 31 | 32 | keypoints_info = np.concatenate( 33 | (keypoints, scores[..., None]), axis=-1) 34 | # compute neck joint 35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 36 | # neck score when visualizing pred 37 | neck[:, 2:4] = np.logical_and( 38 | keypoints_info[:, 5, 2:4] > 0.3, 39 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 40 | new_keypoints_info = np.insert( 41 | keypoints_info, 17, neck, axis=1) 42 | mmpose_idx = [ 43 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 44 | ] 45 | openpose_idx = [ 46 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 47 | ] 48 | new_keypoints_info[:, openpose_idx] = \ 49 | new_keypoints_info[:, mmpose_idx] 50 | keypoints_info = new_keypoints_info 51 | 52 | keypoints, scores = keypoints_info[ 53 | ..., :2], keypoints_info[..., 2] 54 | 55 | return keypoints, scores 56 | 57 | 58 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/demos/01-1.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/01-1.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/demos/02-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/02-1.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/demos/03-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/03-1.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/demos/04-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/demos/04-1.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/facefusion/facefusion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/facefusion/facefusion.jpg -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/ref_imgs/01.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/01.jpeg -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/ref_imgs/02.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/02.jpeg -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/ref_imgs/03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/03.jpeg -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/ref_imgs/04.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/ref_imgs/04.jpeg -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/video/01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/video/01.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/examples/video/02.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD-v2/examples/video/02.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/models/controlnext_vid_svd.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D 10 | 11 | 12 | class ControlNeXtSDVModel(ModelMixin, ConfigMixin): 13 | _supports_gradient_checkpointing = True 14 | 15 | @register_to_config 16 | def __init__( 17 | self, 18 | time_embed_dim = 256, 19 | in_channels = [128, 128], 20 | out_channels = [128, 256], 21 | groups = [4, 8] 22 | ): 23 | super().__init__() 24 | 25 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0) 26 | self.time_embedding = TimestepEmbedding(128, time_embed_dim) 27 | self.embedding = nn.Sequential( 28 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 29 | nn.GroupNorm(2, 64), 30 | nn.ReLU(), 31 | nn.Conv2d(64, 64, kernel_size=3), 32 | nn.GroupNorm(2, 64), 33 | nn.ReLU(), 34 | nn.Conv2d(64, 128, kernel_size=3), 35 | nn.GroupNorm(2, 128), 36 | nn.ReLU(), 37 | ) 38 | 39 | self.down_res = nn.ModuleList() 40 | self.down_sample = nn.ModuleList() 41 | for i in range(len(in_channels)): 42 | self.down_res.append( 43 | ResnetBlock2D( 44 | in_channels=in_channels[i], 45 | out_channels=out_channels[i], 46 | temb_channels=time_embed_dim, 47 | groups=groups[i] 48 | ), 49 | ) 50 | self.down_sample.append( 51 | Downsample2D( 52 | out_channels[i], 53 | use_conv=True, 54 | out_channels=out_channels[i], 55 | padding=1, 56 | name="op", 57 | ) 58 | ) 59 | 60 | self.mid_convs = nn.ModuleList() 61 | self.mid_convs.append(nn.Sequential( 62 | nn.Conv2d( 63 | in_channels=out_channels[-1], 64 | out_channels=out_channels[-1], 65 | kernel_size=3, 66 | stride=1, 67 | padding=1 68 | ), 69 | nn.ReLU(), 70 | nn.GroupNorm(8, out_channels[-1]), 71 | nn.Conv2d( 72 | in_channels=out_channels[-1], 73 | out_channels=out_channels[-1], 74 | kernel_size=3, 75 | stride=1, 76 | padding=1 77 | ), 78 | nn.GroupNorm(8, out_channels[-1]), 79 | )) 80 | self.mid_convs.append( 81 | nn.Conv2d( 82 | in_channels=out_channels[-1], 83 | out_channels=320, 84 | kernel_size=1, 85 | stride=1, 86 | )) 87 | 88 | self.scale = 1. 89 | 90 | def _set_gradient_checkpointing(self, module, value=False): 91 | if hasattr(module, "gradient_checkpointing"): 92 | module.gradient_checkpointing = value 93 | 94 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 95 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 96 | """ 97 | Sets the attention processor to use [feed forward 98 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 99 | 100 | Parameters: 101 | chunk_size (`int`, *optional*): 102 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 103 | over each tensor of dim=`dim`. 104 | dim (`int`, *optional*, defaults to `0`): 105 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 106 | or dim=1 (sequence length). 
107 | """ 108 | if dim not in [0, 1]: 109 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 110 | 111 | # By default chunk size is 1 112 | chunk_size = chunk_size or 1 113 | 114 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 115 | if hasattr(module, "set_chunk_feed_forward"): 116 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 117 | 118 | for child in module.children(): 119 | fn_recursive_feed_forward(child, chunk_size, dim) 120 | 121 | for module in self.children(): 122 | fn_recursive_feed_forward(module, chunk_size, dim) 123 | 124 | def forward( 125 | self, 126 | sample: torch.FloatTensor, 127 | timestep: Union[torch.Tensor, float, int], 128 | ): 129 | 130 | timesteps = timestep 131 | if not torch.is_tensor(timesteps): 132 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 133 | # This would be a good case for the `match` statement (Python 3.10+) 134 | is_mps = sample.device.type == "mps" 135 | if isinstance(timestep, float): 136 | dtype = torch.float32 if is_mps else torch.float64 137 | else: 138 | dtype = torch.int32 if is_mps else torch.int64 139 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 140 | elif len(timesteps.shape) == 0: 141 | timesteps = timesteps[None].to(sample.device) 142 | 143 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 144 | batch_size, num_frames = sample.shape[:2] 145 | timesteps = timesteps.expand(batch_size) 146 | 147 | t_emb = self.time_proj(timesteps) 148 | 149 | # `Timesteps` does not contain any weights and will always return f32 tensors 150 | # but time_embedding might actually be running in fp16. so we need to cast here. 151 | # there might be better ways to encapsulate this. 
152 | t_emb = t_emb.to(dtype=sample.dtype) 153 | 154 | emb_batch = self.time_embedding(t_emb) 155 | 156 | # Flatten the batch and frames dimensions 157 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 158 | sample = sample.flatten(0, 1) 159 | # Repeat the embeddings num_video_frames times 160 | # emb: [batch, channels] -> [batch * frames, channels] 161 | emb = emb_batch.repeat_interleave(num_frames, dim=0) 162 | 163 | sample = self.embedding(sample) 164 | 165 | for res, downsample in zip(self.down_res, self.down_sample): 166 | sample = res(sample, emb) 167 | sample = downsample(sample, emb) 168 | 169 | sample = self.mid_convs[0](sample) + sample 170 | sample = self.mid_convs[1](sample) 171 | 172 | return { 173 | 'output': sample, 174 | 'scale': self.scale, 175 | } 176 | 177 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/run_controlnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt 6 | from models.controlnext_vid_svd import ControlNeXtSDVModel 7 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel 8 | from transformers import CLIPVisionModelWithProjection 9 | import re 10 | from diffusers import AutoencoderKLTemporalDecoder 11 | from moviepy.editor import ImageSequenceClip 12 | from decord import VideoReader 13 | import argparse 14 | from safetensors.torch import load_file 15 | from utils.pre_process import preprocess 16 | 17 | 18 | def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"): 19 | clip = ImageSequenceClip(samples, fps=fps) 20 | clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate, 21 | ffmpeg_params=["-crf", "18", "-preset", "slow"]) 22 | 23 | def save_vid_side_by_side(batch_output, validation_control_images, output_folder, fps): 24 | # Helper function to convert tensors to PIL images and save as GIF 25 | flattened_batch_output = [img for sublist in batch_output for img in sublist] 26 | video_path = output_folder+'/test_1.mp4' 27 | final_images = [] 28 | outputs = [] 29 | # Helper function to concatenate images horizontally 30 | def get_concat_h(im1, im2): 31 | dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height))) 32 | dst.paste(im1, (0, 0)) 33 | dst.paste(im2, (im1.width, 0)) 34 | return dst 35 | for image_list in zip(validation_control_images, flattened_batch_output): 36 | predict_img = image_list[1].resize(image_list[0].size) 37 | result = get_concat_h(image_list[0], predict_img) 38 | final_images.append(np.array(result)) 39 | outputs.append(np.array(predict_img)) 40 | write_mp4(video_path, final_images, fps=fps) 41 | 42 | output_path = output_folder + "/output.mp4" 43 | write_mp4(output_path, outputs, fps=fps) 44 | 45 | 46 | def load_images_from_folder_to_pil(folder): 47 | images = [] 48 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed 49 | 50 | # Function to extract frame number from the filename 51 | def frame_number(filename): 52 | # First, try the pattern 'frame_x_7fps' 53 | new_pattern_match = re.search(r'frame_(\d+)_7fps', filename) 54 | if new_pattern_match: 55 | return int(new_pattern_match.group(1)) 56 | # If the new pattern is not found, use the original digit 
extraction method 57 | matches = re.findall(r'\d+', filename) 58 | if matches: 59 | if matches[-1] == '0000' and len(matches) > 1: 60 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000' 61 | return int(matches[-1]) # Otherwise, return the last sequence 62 | return float('inf') # Return 'inf' 63 | 64 | # Sorting files based on frame number 65 | sorted_files = sorted(os.listdir(folder), key=frame_number) 66 | # Load images in sorted order 67 | for filename in sorted_files: 68 | ext = os.path.splitext(filename)[1].lower() 69 | if ext in valid_extensions: 70 | img = Image.open(os.path.join(folder, filename)).convert('RGB') 71 | images.append(img) 72 | 73 | return images 74 | 75 | 76 | def load_images_from_video_to_pil(video_path): 77 | images = [] 78 | 79 | vr = VideoReader(video_path) 80 | length = len(vr) 81 | 82 | for idx in range(length): 83 | frame = vr[idx].asnumpy() 84 | images.append(Image.fromarray(frame)) 85 | return images 86 | 87 | 88 | def parse_args(): 89 | parser = argparse.ArgumentParser( 90 | description="Script to train Stable Diffusion XL for InstructPix2Pix." 91 | ) 92 | 93 | parser.add_argument( 94 | "--pretrained_model_name_or_path", 95 | type=str, 96 | default=None, 97 | required=True 98 | ) 99 | 100 | parser.add_argument( 101 | "--validation_control_images_folder", 102 | type=str, 103 | default=None, 104 | required=False, 105 | ) 106 | 107 | parser.add_argument( 108 | "--validation_control_video_path", 109 | type=str, 110 | default=None, 111 | required=False, 112 | ) 113 | 114 | parser.add_argument( 115 | "--output_dir", 116 | type=str, 117 | default=None, 118 | required=True 119 | ) 120 | 121 | parser.add_argument( 122 | "--height", 123 | type=int, 124 | default=768, 125 | required=False 126 | ) 127 | 128 | parser.add_argument( 129 | "--width", 130 | type=int, 131 | default=512, 132 | required=False 133 | ) 134 | 135 | parser.add_argument( 136 | "--guidance_scale", 137 | type=float, 138 | default=2., 139 | required=False 140 | ) 141 | 142 | parser.add_argument( 143 | "--num_inference_steps", 144 | type=int, 145 | default=25, 146 | required=False 147 | ) 148 | 149 | 150 | parser.add_argument( 151 | "--controlnext_path", 152 | type=str, 153 | default=None, 154 | required=True 155 | ) 156 | 157 | parser.add_argument( 158 | "--unet_path", 159 | type=str, 160 | default=None, 161 | required=True 162 | ) 163 | 164 | parser.add_argument( 165 | "--max_frame_num", 166 | type=int, 167 | default=50, 168 | required=False 169 | ) 170 | 171 | parser.add_argument( 172 | "--ref_image_path", 173 | type=str, 174 | default=None, 175 | required=True 176 | ) 177 | 178 | parser.add_argument( 179 | "--batch_frames", 180 | type=int, 181 | default=14, 182 | required=False 183 | ) 184 | 185 | parser.add_argument( 186 | "--overlap", 187 | type=int, 188 | default=4, 189 | required=False 190 | ) 191 | 192 | parser.add_argument( 193 | "--sample_stride", 194 | type=int, 195 | default=2, 196 | required=False 197 | ) 198 | 199 | args = parser.parse_args() 200 | return args 201 | 202 | 203 | def load_tensor(tensor_path): 204 | if os.path.splitext(tensor_path)[1] == '.bin': 205 | return torch.load(tensor_path) 206 | elif os.path.splitext(tensor_path)[1] == ".safetensors": 207 | return load_file(tensor_path) 208 | else: 209 | print("without supported tensors") 210 | os._exit() 211 | 212 | 213 | # Main script 214 | if __name__ == "__main__": 215 | args = parse_args() 216 | 217 | assert (args.validation_control_images_folder is None) ^ (args.validation_control_video_path is 
None), "must and only one of [validation_control_images_folder, validation_control_video_path] should be given" 218 | 219 | unet = UNetSpatioTemporalConditionControlNeXtModel.from_pretrained( 220 | args.pretrained_model_name_or_path, 221 | subfolder="unet", 222 | low_cpu_mem_usage=True, 223 | ) 224 | controlnext = ControlNeXtSDVModel() 225 | controlnext.load_state_dict(load_tensor(args.controlnext_path)) 226 | unet.load_state_dict(load_tensor(args.unet_path), strict=False) 227 | 228 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 229 | args.pretrained_model_name_or_path, subfolder="image_encoder") 230 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 231 | args.pretrained_model_name_or_path, subfolder="vae") 232 | 233 | pipeline = StableVideoDiffusionPipelineControlNeXt.from_pretrained( 234 | args.pretrained_model_name_or_path, 235 | controlnext=controlnext, 236 | unet=unet, 237 | vae=vae, 238 | image_encoder=image_encoder) 239 | # pipeline.to(dtype=torch.float16) 240 | pipeline.enable_model_cpu_offload() 241 | 242 | os.makedirs(args.output_dir, exist_ok=True) 243 | 244 | # Inference and saving loop 245 | # ref_image = Image.open(args.ref_image_path).convert('RGB') 246 | # ref_image = ref_image.resize((args.width, args.height)) 247 | # validation_control_images = [img.resize((args.width, args.height)) for img in validation_control_images] 248 | 249 | validation_control_images, ref_image = preprocess(args.validation_control_video_path, args.ref_image_path, width=args.width, height=args.height, max_frame_num=args.max_frame_num, sample_stride=args.sample_stride) 250 | 251 | 252 | final_result = [] 253 | frames = args.batch_frames 254 | num_frames = min(args.max_frame_num, len(validation_control_images)) 255 | 256 | for i in range(num_frames): 257 | validation_control_images[i] = Image.fromarray(np.array(validation_control_images[i])) 258 | 259 | video_frames = pipeline( 260 | ref_image, 261 | validation_control_images[:num_frames], 262 | decode_chunk_size=2, 263 | num_frames=num_frames, 264 | motion_bucket_id=127.0, 265 | fps=7, 266 | controlnext_cond_scale=1.0, 267 | width=args.width, 268 | height=args.height, 269 | min_guidance_scale=args.guidance_scale, 270 | max_guidance_scale=args.guidance_scale, 271 | frames_per_batch=frames, 272 | num_inference_steps=args.num_inference_steps, 273 | overlap=args.overlap).frames[0] 274 | final_result.append(video_frames) 275 | 276 | fps =VideoReader(args.validation_control_video_path).get_avg_fps() // args.sample_stride 277 | 278 | save_vid_side_by_side( 279 | final_result, 280 | validation_control_images[:num_frames], 281 | args.output_dir, 282 | fps=fps) -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/script.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python run_controlnext.py \ 2 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 3 | --output_dir outputs \ 4 | --max_frame_num 240 \ 5 | --guidance_scale 3 \ 6 | --batch_frames 24 \ 7 | --sample_stride 2 \ 8 | --overlap 6 \ 9 | --height 1024 \ 10 | --width 576 \ 11 | --controlnext_path pretrained/controlnet.bin \ 12 | --unet_path pretrained/unet.bin \ 13 | --validation_control_video_path examples/video/02.mp4 \ 14 | --ref_image_path examples/ref_imgs/01.jpeg 15 | 16 | -------------------------------------------------------------------------------- /ControlNeXt-SVD-v2/utils/pre_process.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import math 5 | from omegaconf import OmegaConf 6 | from datetime import datetime 7 | from pathlib import Path 8 | from PIL import Image 9 | import numpy as np 10 | import torch.jit 11 | from torchvision.datasets.folder import pil_loader 12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop 13 | from torchvision.transforms.functional import to_pil_image 14 | from dwpose.preprocess import get_image_pose, get_video_pose 15 | 16 | ASPECT_RATIO = 9 / 16 17 | 18 | def preprocess(video_path, image_path, width=576, height=1024, sample_stride=2, max_frame_num=None): 19 | """preprocess ref image pose and video pose 20 | 21 | Args: 22 | video_path (str): input video pose path 23 | image_path (str): reference image path 24 | resolution (int, optional): Defaults to 576. 25 | sample_stride (int, optional): Defaults to 2. 26 | """ 27 | image_pixels = pil_loader(image_path) 28 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w) 29 | h, w = image_pixels.shape[-2:] 30 | ############################ compute target h/w according to original aspect ratio ############################### 31 | # if h>w: 32 | # w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64 33 | # else: 34 | # w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution 35 | w_target, h_target = width, height 36 | h_w_ratio = float(h) / float(w) 37 | if h_w_ratio < h_target / w_target: 38 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio) 39 | else: 40 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target 41 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None) 42 | image_pixels = center_crop(image_pixels, [h_target, w_target]) 43 | image_pixels = image_pixels.permute((1, 2, 0)).numpy() 44 | ##################################### get image&video pose value ################################################# 45 | image_pose = get_image_pose(image_pixels) 46 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, max_frame_num=max_frame_num) 47 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose]) 48 | # image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2)) 49 | image_pixels = Image.fromarray(image_pixels) 50 | pose_pixels = [Image.fromarray(p.transpose((1,2,0))) for p in pose_pixels] 51 | # return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1 52 | return pose_pixels, image_pixels 53 | 54 | -------------------------------------------------------------------------------- /ControlNeXt-SVD/README.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-SVD 2 | 3 | This is our implementation of ControlNeXt based on [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). It can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) with a more concise and efficient architecture. 4 | 5 | Compared to image generation, video generation poses significantly greater challenges. While direct training of the generation model using our method is feasible, we also employ various engineering strategies to enhance performance. Although they are irrespective of academic algorithms. 
6 | 7 | 8 | > Please refer to [Examples](#examples) for further intuitive details.\ 9 | > Please refer to [Base model](#base-model) for more details about the base model we use. \ 10 | > Please refer to [Inference](#inference) for more details regarding installation and inference.\ 11 | > Please refer to [Advanced Performance](#advanced-performance) for more details on achieving better performance.\ 12 | > Please refer to [Limitations](#limitations) for more details about the limitations of the current work. 13 | 14 | # Examples 15 | If you can't load the videos, you can also directly download them from [here](outputs). 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | # Base Model 26 | 27 | The base model's generation capability significantly influences video generation. Initially, we train the generation model using our method, which is based on [Stable Video Diffusion XT-1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). However, the original SVD model exhibits weaknesses in generating human features, particularly hands and faces. 28 | 29 | Therefore, we first conduct continual pretraining of [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) on our curated collection of human-related videos to improve its ability to generate human features. Subsequently, we fine-tune it for specific downstream tasks, i.e., generating dance videos guided by pose sequences. 30 | 31 | In this project, we release all our models, including the base model and the fine-tuned model. You can download them from: 32 | * [Finetuned Model](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/finetune): We fine-tune our own trained base model for the downstream task using our proposed method, incorporating only `50M` learnable parameters. For your convenience, we directly merge the pretrained base model with the fine-tuned parameters, so you can download this consolidated model. 33 | * [Continuously Pretrained Model](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/pretrained): We continually pretrain the base model on our collected human-related data, starting from [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1). This improves performance in generating human features, particularly hands and faces. However, due to the complexity of human motion, it still shares SVD's tendency to prefer generating static videos. Nonetheless, it excels in downstream tasks. We encourage more participation to further enhance the base model for generating human-related videos. 34 | * [Finetuned Parameters](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/learned_params): The parameters involved in the fine-tuning process. Alternatively, you can directly download the `Finetuned Model`. 35 | 36 | 37 | # Inference 38 | 39 | 1. Clone our repository 40 | 2. `cd ControlNeXt-SVD` 41 | 3. Download the pretrained weights into `pretrained/` from [here](https://huggingface.co/Pbihao/ControlNeXt/tree/main/ControlNeXt-SVD/finetune). (For more details, please refer to [Base Model](#base-model).) 42 | 4. 
Run the script 43 | 44 | ```bash 45 | python run_controlnext.py \ 46 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 47 | --validation_control_video_path examples/pose/pose.mp4 \ 48 | --output_dir outputs/tiktok \ 49 | --controlnext_path pretrained/controlnet.bin \ 50 | --unet_path pretrained/unet_fp16.bin \ 51 | --ref_image_path examples/ref_imgs/tiktok.png 52 | ``` 53 | 54 | > --pretrained_model_name_or_path : the pretrained base model; we pretrain and finetune our models based on [SVD-XT1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1)\ 55 | > --controlnext_path : the path to the ControlNeXt module (a lightweight module) \ 56 | > --unet_path : the path to the fine-tuned UNet weights 57 | 58 | 5. Face Enhancement (Optional, recommended for bad faces) 59 | 60 | > Currently, the model is not specifically trained for IP consistency, as there are already many mature tools available. Additionally, alternatives like Animate Anyone also adopt such post-processing techniques. 61 | 62 | a. Clone [Face Fusion](https://github.com/facefusion/facefusion): \ 63 | ```git clone https://github.com/facefusion/facefusion``` 64 | 65 | b. Enter the directory:\ 66 | ```cd facefusion``` 67 | 68 | c. Install facefusion (we recommend creating a new conda virtual environment to avoid conflicts):\ 69 | ```python install.py``` 70 | 71 | d. Run the command: 72 | ``` 73 | python run.py \ 74 | -s ../outputs/collected/demo.jpg \ 75 | -t ../outputs/collected/demo.mp4 \ 76 | -o ../outputs/collected/out.mp4 \ 77 | --headless \ 78 | --execution-providers cuda \ 79 | --face-selector-mode one 80 | ``` 81 | 82 | > -s: the reference image \ 83 | > -t: the path to the original video\ 84 | > -o: the path to store the refined video\ 85 | > --headless: no GUI needed\ 86 | > --execution-providers cuda: use CUDA for acceleration (if available; otherwise the CPU is usually enough) 87 | 88 | # Advanced Performance 89 | In this section, we delve into additional details and our own experiences with enhancing video generation. These factors are independent of the algorithm itself, yet crucial for achieving superior results. Many closely related works incorporate these strategies. 90 | 91 | ### Reference Image 92 | 93 | It is crucial to ensure that the reference image is clear and easy to understand, and in particular that the face in the reference is aligned with the pose. 94 | 95 | 96 | ### Face Enhancement 97 | 98 | Most related works utilize face enhancement as part of the post-processing. This is especially relevant when generating videos based on images of unfamiliar individuals, such as friends, who were not included in the base model's pretraining and are therefore unseen, out-of-distribution data. 99 | 100 | We recommend [Facefusion](https://github.com/facefusion/facefusion 101 | ) for the post-processing, and please let us know if you have a better solution. 102 | 103 | Please refer to [Facefusion](https://github.com/facefusion/facefusion 104 | ) for more details. 105 | 106 | ![Facefusion](examples/facefusion/facefusion.jpg) 107 | 108 | 109 | ### Continuous Fine-tuning 110 | 111 | To significantly enhance performance on a specific pose sequence, you can continue fine-tuning the model for just a few hundred steps. 112 | 113 | We will release the related fine-tuning code later. 114 | 115 | ### Pose Generation 116 | 117 | We adopt [DWPose](https://github.com/IDEA-Research/DWPose) for pose generation; a minimal extraction sketch is shown below.
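> The snippet below is a minimal, illustrative sketch of preparing such a pose control video. It borrows the `preprocess` helper shipped in `ControlNeXt-SVD-v2/utils/pre_process.py`, so it assumes you run it from the `ControlNeXt-SVD-v2` directory with the DWPose weights in place; the input paths, output name, and fps are placeholder values rather than part of the released scripts.

```python
# Hedged sketch: extract DWPose control frames with the ControlNeXt-SVD-v2
# helper and save them as an mp4 pose video.
import numpy as np
from moviepy.editor import ImageSequenceClip
from utils.pre_process import preprocess  # ControlNeXt-SVD-v2/utils/pre_process.py

# Returns a list of PIL pose images plus the resized/cropped reference image.
pose_frames, ref_image = preprocess(
    "examples/video/01.mp4",       # driving (pose source) video, placeholder path
    "examples/ref_imgs/01.jpeg",   # reference image, placeholder path
    width=576, height=1024, sample_stride=2)

ref_image.save("ref_aligned.png")
ImageSequenceClip([np.array(f) for f in pose_frames], fps=14).write_videofile("pose.mp4")
```

> The resulting `pose.mp4` can then be passed to `run_controlnext.py` via `--validation_control_video_path`.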
118 | 119 | # Limitations 120 | 121 | ## IP Consistency 122 | 123 | We did not prioritize maintaining IP consistency during the development of the generation model and now rely on a helper model for face enhancement. 124 | 125 | However, additional training can be implemented to ensure IP consistency moving forward. 126 | 127 | This also leaves a possible direction for futher improvement. 128 | 129 | ## Base model 130 | 131 | The base model plays a crucial role in generating human features, particularly hands and faces. We encourage collaboration to improve the base model for enhanced human-related video generation. 132 | 133 | # TODO 134 | 135 | * Training and finetune code 136 | -------------------------------------------------------------------------------- /ControlNeXt-SVD/examples/facefusion/facefusion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/facefusion/facefusion.jpg -------------------------------------------------------------------------------- /ControlNeXt-SVD/examples/pose/pose.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/pose/pose.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/examples/ref_imgs/spiderman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/ref_imgs/spiderman.jpg -------------------------------------------------------------------------------- /ControlNeXt-SVD/examples/ref_imgs/tiktok.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/examples/ref_imgs/tiktok.png -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/chair/chair.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/chair/chair.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/collected/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/demo.jpg -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/collected/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/demo.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/collected/out2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/collected/out2.mp4 
-------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/spiderman/spiderman.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/spiderman/spiderman.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/star/star.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/star/star.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/outputs/tiktok/tiktok.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/ControlNeXt/ab4b3acf912cc178d23bbf003369dfb657fc8d01/ControlNeXt-SVD/outputs/tiktok/tiktok.mp4 -------------------------------------------------------------------------------- /ControlNeXt-SVD/run_controlnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt 6 | from models.controlnext_vid_svd import ControlNeXtSDVModel 7 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel 8 | from transformers import CLIPVisionModelWithProjection 9 | import re 10 | from diffusers import AutoencoderKLTemporalDecoder 11 | from moviepy.editor import ImageSequenceClip 12 | from decord import VideoReader 13 | import argparse 14 | from safetensors.torch import load_file 15 | 16 | 17 | def write_mp4(video_path, samples, fps=14): 18 | clip = ImageSequenceClip(samples, fps=fps) 19 | clip.write_videofile(video_path, audio_codec="aac") 20 | 21 | def save_vid_side_by_side(batch_output, validation_control_images, output_folder, fps): 22 | # Helper function to convert tensors to PIL images and save as GIF 23 | flattened_batch_output = [img for sublist in batch_output for img in sublist] 24 | video_path = output_folder+'/test_1.mp4' 25 | final_images = [] 26 | outputs = [] 27 | # Helper function to concatenate images horizontally 28 | def get_concat_h(im1, im2): 29 | dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height))) 30 | dst.paste(im1, (0, 0)) 31 | dst.paste(im2, (im1.width, 0)) 32 | return dst 33 | for image_list in zip(validation_control_images, flattened_batch_output): 34 | predict_img = image_list[1].resize(image_list[0].size) 35 | result = get_concat_h(image_list[0], predict_img) 36 | final_images.append(np.array(result)) 37 | outputs.append(np.array(predict_img)) 38 | write_mp4(video_path, final_images, fps=fps) 39 | 40 | output_path = output_folder + "/output.mp4" 41 | write_mp4(output_path, outputs, fps=fps) 42 | 43 | 44 | def load_images_from_folder_to_pil(folder): 45 | images = [] 46 | valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed 47 | 48 | # Function to extract frame number from the filename 49 | def frame_number(filename): 50 | # First, try the pattern 'frame_x_7fps' 51 | new_pattern_match = re.search(r'frame_(\d+)_7fps', filename) 52 | if new_pattern_match: 53 | return int(new_pattern_match.group(1)) 54 | 
# If the new pattern is not found, use the original digit extraction method 55 | matches = re.findall(r'\d+', filename) 56 | if matches: 57 | if matches[-1] == '0000' and len(matches) > 1: 58 | return int(matches[-2]) # Return the second-to-last sequence if the last is '0000' 59 | return int(matches[-1]) # Otherwise, return the last sequence 60 | return float('inf') # Return 'inf' 61 | 62 | # Sorting files based on frame number 63 | sorted_files = sorted(os.listdir(folder), key=frame_number) 64 | # Load images in sorted order 65 | for filename in sorted_files: 66 | ext = os.path.splitext(filename)[1].lower() 67 | if ext in valid_extensions: 68 | img = Image.open(os.path.join(folder, filename)).convert('RGB') 69 | images.append(img) 70 | 71 | return images 72 | 73 | 74 | def load_images_from_video_to_pil(video_path): 75 | images = [] 76 | 77 | vr = VideoReader(video_path) 78 | length = len(vr) 79 | 80 | for idx in range(length): 81 | frame = vr[idx].asnumpy() 82 | images.append(Image.fromarray(frame)) 83 | return images 84 | 85 | 86 | def parse_args(): 87 | parser = argparse.ArgumentParser( 88 | description="Script to train Stable Diffusion XL for InstructPix2Pix." 89 | ) 90 | 91 | parser.add_argument( 92 | "--pretrained_model_name_or_path", 93 | type=str, 94 | default=None, 95 | required=True 96 | ) 97 | 98 | parser.add_argument( 99 | "--validation_control_images_folder", 100 | type=str, 101 | default=None, 102 | required=False, 103 | ) 104 | 105 | parser.add_argument( 106 | "--validation_control_video_path", 107 | type=str, 108 | default=None, 109 | required=False, 110 | ) 111 | 112 | parser.add_argument( 113 | "--output_dir", 114 | type=str, 115 | default=None, 116 | required=True 117 | ) 118 | 119 | parser.add_argument( 120 | "--height", 121 | type=int, 122 | default=768, 123 | required=False 124 | ) 125 | 126 | parser.add_argument( 127 | "--width", 128 | type=int, 129 | default=512, 130 | required=False 131 | ) 132 | 133 | parser.add_argument( 134 | "--guidance_scale", 135 | type=float, 136 | default=3.5, 137 | required=False 138 | ) 139 | 140 | parser.add_argument( 141 | "--num_inference_steps", 142 | type=int, 143 | default=25, 144 | required=False 145 | ) 146 | 147 | parser.add_argument( 148 | "--fps", 149 | type=int, 150 | default=14, 151 | required=False 152 | ) 153 | 154 | parser.add_argument( 155 | "--controlnext_path", 156 | type=str, 157 | default=None, 158 | required=True 159 | ) 160 | 161 | parser.add_argument( 162 | "--unet_path", 163 | type=str, 164 | default=None, 165 | required=True 166 | ) 167 | 168 | parser.add_argument( 169 | "--ref_image_path", 170 | type=str, 171 | default=None, 172 | required=True 173 | ) 174 | 175 | args = parser.parse_args() 176 | return args 177 | 178 | 179 | def load_tensor(tensor_path): 180 | if os.path.splitext(tensor_path)[1] == '.bin': 181 | return torch.load(tensor_path) 182 | elif os.path.splitext(tensor_path)[1] == ".safetensors": 183 | return load_file(tensor_path) 184 | else: 185 | print("without supported tensors") 186 | os._exit() 187 | 188 | 189 | # Main script 190 | if __name__ == "__main__": 191 | args = parse_args() 192 | 193 | assert (args.validation_control_images_folder is None) ^ (args.validation_control_video_path is None), "must and only one of [validation_control_images_folder, validation_control_video_path] should be given" 194 | if args.validation_control_images_folder is not None: 195 | validation_control_images = load_images_from_folder_to_pil(args.validation_control_images_folder) 196 | else: 197 | 
validation_control_images = load_images_from_video_to_pil(args.validation_control_video_path) 198 | 199 | unet = UNetSpatioTemporalConditionControlNeXtModel.from_pretrained( 200 | args.pretrained_model_name_or_path, 201 | subfolder="unet", 202 | low_cpu_mem_usage=True, 203 | variant="fp16", 204 | ) 205 | controlnext = ControlNeXtSDVModel() 206 | controlnext.load_state_dict(load_tensor(args.controlnext_path)) 207 | unet.load_state_dict(load_tensor(args.unet_path), strict=False) 208 | 209 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 210 | args.pretrained_model_name_or_path, subfolder="image_encoder", variant="fp16") 211 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 212 | args.pretrained_model_name_or_path, subfolder="vae", variant="fp16") 213 | 214 | pipeline = StableVideoDiffusionPipelineControlNeXt.from_pretrained( 215 | args.pretrained_model_name_or_path, 216 | controlnext=controlnext, 217 | unet=unet, 218 | vae=vae, 219 | image_encoder=image_encoder) 220 | pipeline.to(dtype=torch.float16) 221 | pipeline.enable_model_cpu_offload() 222 | 223 | os.makedirs(args.output_dir, exist_ok=True) 224 | 225 | # Inference and saving loop 226 | final_result = [] 227 | ref_image = Image.open(args.ref_image_path).convert('RGB') 228 | frames = 14 229 | num_frames = len(validation_control_images) 230 | 231 | video_frames = pipeline( 232 | ref_image, 233 | validation_control_images[:num_frames], 234 | decode_chunk_size=4, 235 | num_frames=num_frames, 236 | motion_bucket_id=127.0, 237 | fps=7, 238 | controlnext_cond_scale=1.0, 239 | width=args.width, 240 | height=args.height, 241 | min_guidance_scale=args.guidance_scale, 242 | max_guidance_scale=args.guidance_scale, 243 | frames_per_batch=frames, 244 | num_inference_steps=args.num_inference_steps, 245 | overlap=4).frames[0] 246 | final_result.append(video_frames) 247 | 248 | save_vid_side_by_side( 249 | final_result, 250 | validation_control_images[:num_frames], 251 | args.output_dir, 252 | fps=args.fps) 253 | -------------------------------------------------------------------------------- /ControlNeXt-SVD/script.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | python run_controlnext.py \ 4 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 5 | --validation_control_video_path examples/pose/pose.mp4 \ 6 | --output_dir outputs/tiktok \ 7 | --controlnext_path pretrained/controlnet.bin \ 8 | --unet_path pretrained/unet_fp16.bin \ 9 | --ref_image_path examples/ref_imgs/tiktok.png 10 | 11 | 12 | python run_controlnext.py \ 13 | --pretrained_model_name_or_path stabilityai/stable-video-diffusion-img2vid-xt-1-1 \ 14 | --validation_control_video_path examples/pose/pose.mp4 \ 15 | --output_dir outputs/spiderman \ 16 | --controlnext_path pretrained/controlnet.bin \ 17 | --unet_path pretrained/unet_fp16.bin \ 18 | --ref_image_path examples/ref_imgs/spiderman.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # 🌀 ControlNeXt 3 | 4 | 5 | 6 | ## [📝 Project Page](https://pbihao.github.io/projects/controlnext/index.html) | [📚 Paper](https://arxiv.org/abs/2408.06070) | [🗂️ Demo (SDXL)](https://huggingface.co/spaces/Eugeoter/ControlNeXt) 7 | 8 | 9 | **ControlNeXt** is our official implementation for controllable generation, supporting both images and videos while incorporating diverse forms of control information. 
In this project, we propose a new method that reduces trainable parameters by up to 90% compared with ControlNet, achieving faster convergence and outstanding efficiency. This method can be directly combined with LoRA techniques to alter the style and ensure more stable generation. Please refer to the examples for more details. 10 | 11 | We provide an online demo of [ControlNeXt-SDXL](./ControlNeXt-SDXL/). Due to the high resource requirements of SVD, we are unable to offer it online. 12 | 13 | > This project is still undergoing iterative development. The code and model may be updated at any time. More information will be provided later. 14 | 15 | # Experiences 16 | We share more of our training experiences [here](./experiences.md) and in this [issue](https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333). 17 | We spent a lot of time gathering these insights and now share them with all of you. We hope they help! 18 | 19 | # Model Zoo 20 | 21 | - **ControlNeXt-SDXL** [ [Link](ControlNeXt-SDXL) ] : Controllable image generation. Our model is built upon [Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0). Fewer trainable parameters, faster convergence, improved efficiency, and LoRA integration. 22 | 23 | - **ControlNeXt-SDXL-Training** [ [Link](ControlNeXt-SDXL-Training) ] : The training scripts for our `ControlNeXt-SDXL` [ [Link](ControlNeXt-SDXL) ]. 24 | 25 | - **ControlNeXt-SVD-v2** [ [Link](ControlNeXt-SVD-v2) ] : Generates videos controlled by a sequence of human poses. In v2, we implement several improvements: a higher-quality collected training dataset, larger training and inference batch frames, higher generation resolution, enhanced human-related video generation through continual training, and pose alignment at inference to improve overall performance. 26 | 27 | - **ControlNeXt-SVD-v2-Training** [ [Link](ControlNeXt-SVD-v2-Training) ] : The training scripts for our `ControlNeXt-SVD-v2` [ [Link](ControlNeXt-SVD-v2) ]. 28 | 29 | - **ControlNeXt-SVD** [ [Link](ControlNeXt-SVD) ] : Generates videos controlled by a sequence of human poses. This can be seen as an attempt to replicate the implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). However, our model is built upon [Stable Video Diffusion](https://stability.ai/stable-video), employing a more concise architecture. 30 | 31 | - **ControlNeXt-SD1.5** [ [Link](ControlNeXt-SD1.5) ] : Controllable image generation. Our model is built upon [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Fewer trainable parameters, faster convergence, improved efficiency, and LoRA integration. 32 | 33 | - **ControlNeXt-SD1.5-Training** [ [Link](ControlNeXt-SD1.5-Training) ] : The training scripts for our `ControlNeXt-SD1.5` [ [Link](ControlNeXt-SD1.5) ]. 34 | 35 | - **ControlNeXt-SD3** [ [Link](ControlNeXt-SD3) ] : We regret to inform you that ControlNeXt-SD3 was trained with protected, private data and code, and therefore cannot be released. 36 | 37 | 38 | 39 | # 🎥 Examples 40 | ### For more examples, please refer to our [Project page](https://pbihao.github.io/projects/controlnext/index.html). 41 | 42 | ### [ControlNeXt-SDXL](ControlNeXt-SDXL) 43 | 44 |
45 | demo1 46 | demo2 47 | demo3 48 | demo5 49 |
50 | 51 | ### [ControlNeXt-SVD-v2](ControlNeXt-SVD-v2) 52 | If you can't load the videos, you can also directly download them from [here](examples/demos) and [here](examples/video). 53 | Or you can view them from our [Project Page](https://pbihao.github.io/projects/controlnext/index.html) or [BiliBili](https://www.bilibili.com/video/BV1wJYbebEE7/?buvid=YC4E03C93B119ADD4080B0958DE73F9DDCAC&from_spmid=dt.dt.video.0&is_story_h5=false&mid=y82Gz7uArS6jTQ6zuqJj3w%3D%3D&p=1&plat_id=114&share_from=ugc&share_medium=iphone&share_plat=ios&share_session_id=4E5549FC-0710-4030-BD2C-CDED80B46D08&share_source=WEIXIN&share_source=weixin&share_tag=s_i×tamp=1723123770&unique_k=XLZLhCq&up_id=176095810&vd_source=3791450598e16da25ecc2477fc7983db). 54 | 55 | 56 | 57 | 60 | 63 | 64 | 65 | 68 | 71 | 72 | 73 |
74 | 75 | 76 | 77 | 78 | 79 | 80 | ### [ControlNeXt-SVD](ControlNeXt-SVD) 81 | If you can't load the videos, you can also directly download them from [here](ControlNeXt-SVD/outputs). 82 | 83 | 84 | 85 | 86 | 87 | 90 | 91 | 92 | 93 | 96 | 99 | 100 |
101 | 102 | 103 | 104 | ### [ControlNeXt-SD1.5](ControlNeXt-SD1.5) 105 | 106 |
107 | DreamShaper 108 | 109 | 110 | Anythingv3 111 | 112 | 113 | Anythingv3 114 |
115 | 116 | 117 | ### If you find this work useful, please consider citing: 118 | ``` 119 | @article{peng2024controlnext, 120 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation}, 121 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya}, 122 | journal={arXiv preprint arXiv:2408.06070}, 123 | year={2024} 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /compress_image.py: -------------------------------------------------------------------------------- 1 | # from PIL import Image 2 | # import cv2 3 | # import numpy as np 4 | 5 | # img_path = "/home/llm/bhpeng/github/ControlAny/ControlAny-SDXL/examples/vidit_depth/condition_0.png" 6 | # save_path = "/home/llm/bhpeng/github/ControlAny/ControlAny-SDXL/examples/vidit_depth/condition_02.png" 7 | 8 | # length = 1 9 | # select_id = [] 10 | 11 | # image = cv2.imread(img_path) 12 | # height, width, _ = image.shape 13 | # part_width = width // length 14 | 15 | # splited_imgs = [] 16 | # for i in range(length): 17 | # left = i * part_width 18 | # right = (i + 1) * part_width if i < length - 1 else width # make sure the last chunk extends to the right edge of the image 19 | 20 | # split_img = image[:, left:right] 21 | # splited_imgs.append(split_img) 22 | 23 | # merge_imgs = [] 24 | # merge_imgs.append(splited_imgs[0]) 25 | # for i in select_id: 26 | # merge_imgs.append(splited_imgs[i]) 27 | # merge_imgs = np.concatenate(merge_imgs, axis=1) 28 | # print(merge_imgs.shape) 29 | # resized_img = cv2.resize(merge_imgs, (merge_imgs.shape[1]//2, merge_imgs.shape[0]//2), interpolation=cv2.INTER_AREA) 30 | # print(resized_img.shape) 31 | # cv2.imwrite(save_path, resized_img, [cv2.IMWRITE_JPEG_QUALITY, 85]) 32 | 33 | # img = cv2.resize(img, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_AREA) 34 | 35 | 36 | # from moviepy.editor import VideoFileClip 37 | # import moviepy.video.io.ffmpeg_writer as ffmpeg_writer 38 | 39 | 40 | # video_path = 'ControlNeXt-SVD/outputs/chair/chair.mp4' 41 | # clip = VideoFileClip(video_path) 42 | 43 | # gif_path = 'ControlNeXt-SVD/outputs/chair/chair.gif' 44 | # clip.write_gif(gif_path, fps=14, program='ffmpeg', opt="nq", fuzz=1, ) 45 | 46 | 47 | from PIL import Image 48 | import os 49 | 50 | def compress_image(input_path, output_path, quality=85): 51 | """ 52 | Compress an image while preserving as much quality as possible. 53 | 54 | :param input_path: path to the original image 55 | :param output_path: path to save the compressed image 56 | :param quality: compression quality in the range 0 to 100, where 100 is the highest quality 57 | """ 58 | # Open the image 59 | with Image.open(input_path) as img: 60 | # Make sure the image is in RGB mode 61 | if img.mode in ("RGBA", "P"): 62 | img = img.convert("RGB") 63 | 64 | # Save the compressed image 65 | img.save(output_path, "JPEG", quality=quality, optimize=True) 66 | 67 | 68 | img_paths = [ 69 | 'ControlNeXt-SDXL/examples/demo/demo1.png', 70 | 'ControlNeXt-SDXL/examples/demo/demo3.png', 71 | 'ControlNeXt-SDXL/examples/demo/demo5.png' 72 | ] 73 | quality = 50 74 | 75 | for src_path in img_paths: 76 | dst_path = src_path 77 | src_path = os.path.join(os.path.split(src_path)[0], 'src_'+os.path.split(src_path)[1]) 78 | os.rename(dst_path, src_path) 79 | dst_path = '.'.join(dst_path.split('.')[:-1])+'.jpg' 80 | compress_image(src_path, dst_path, quality=quality) -------------------------------------------------------------------------------- /experiences.md: -------------------------------------------------------------------------------- 1 | # 🌀 ControlNeXt-Experiences 2 | 3 | As we all know, developing a high-quality model is not just an academic challenge; it also requires
extensive engineering experience. Therefore, we are sharing the insights we gained during this project. These insights are the result of significant time and effort. If you find them helpful, please consider giving us a star or citing our work. 4 | 5 | We hope they help you. 6 | 7 | ### 1. Human-related generation 8 | 9 | As we've mentioned, we only select a small subset of parameters, which is enough to fully adapt the SD1.5 and SDXL backbones. By training fewer than 100 million parameters, we still achieve excellent performance. However, this alone is not sufficient for SD3 and SVD training. This is because, after SDXL, Stability faced significant legal risks due to the generation of highly realistic human images. After that, they stopped refining their later models, such as SVD and SD3, on human-related data to avoid potential risks. 10 | 11 | To achieve optimal performance, it's necessary to first continue training SVD and SD3 on human-related data to develop a robust backbone before fine-tuning. Of course, you can also combine the continual pretraining and the fine-tuning (open all the parameters for training; there will not be a significant difference). This is why we directly provide the full SVD parameters. 12 | 13 | Although this may not be directly related to academia, it is crucial for achieving good performance. 14 | 15 | ### 2. Data 16 | 17 | Due to privacy policies, we are unable to share the data. However, data quality is crucial, as many videos on the internet are highly compressed. It's important to focus on collecting high-quality data without compression. 18 | 19 | ### 3. Hands 20 | 21 | Generating hands is a challenging problem in both video and image generation. To address this, we focus on the following strategies: 22 | 23 | a. Use clear and high-quality data, which is crucial for accurate generation. 24 | 25 | b. Since the hands occupy a relatively small area, we apply a larger loss weight to this region to improve the generation quality. 26 | 27 | ### 4. Pose alignment 28 | 29 | Thanks to [MimicMotion](https://github.com/Tencent/MimicMotion). SVD performs poorly, especially with large motions, so it is important to avoid large movements and shifts. Please note that in [preprocess](https://github.com/dvlab-research/ControlNeXt/blob/main/ControlNeXt-SVD-v2/dwpose/preprocess.py) there is an alignment step between the reference image and the pose. This is crucial. 30 | 31 | 32 | ### 5. Control level 33 | 34 | You may notice that we adopt a magic number when adding the conditions. 35 | 36 | For example, in `ControlNeXt-SVD-v2/models/unet_spatio_temporal_condition_controlnext.py`: 37 | ```python 38 | sample = sample + conditional_controls * scale * 0.2 39 | ``` 40 | 41 | Notice that we multiply by `0.2`. This hyperparameter adjusts the control level: increasing it strengthens the control. 42 | 43 | However, if this value is set too high, the control may become overly strong and degrade the final generated images. 44 | 45 | You can adjust it to get a good result. In our experience, for dense controls such as super-resolution or depth, it needs to be set to `1`. 46 | 47 | 48 | ### 6. Training parameters 49 | 50 | One of the most important findings is that directly training the base model yields better performance compared to methods like LoRA, Adapter, and others. Even when we train the base model, we only select a small subset of the pre-trained parameters. You can also adaptively adjust the number of selected parameters (a minimal selection sketch is shown below).
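As a rough, self-contained sketch of this selective training idea (illustration only; the name patterns, base model, and learning rate below are assumptions, not the selection rule used in the released training scripts), you can freeze the pretrained UNet and unfreeze a small, name-filtered subset of its parameters:

```python
# Illustrative sketch of selective fine-tuning: train only a small subset of
# the pretrained UNet parameters instead of adding LoRA/Adapter modules.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet")   # example base model

trainable_patterns = ("attn1", "attn2", "time_emb")       # hypothetical selection rule
unet.requires_grad_(False)                                # freeze everything first

trainable_params = []
for name, param in unet.named_parameters():
    if any(pattern in name for pattern in trainable_patterns):
        param.requires_grad_(True)                        # unfreeze the selected subset
        trainable_params.append(param)

optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
print(f"trainable: {sum(p.numel() for p in trainable_params) / 1e6:.1f}M "
      f"of {sum(p.numel() for p in unet.parameters()) / 1e6:.1f}M UNet parameters")
```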
For example, with high-quality data, having more trainable parameters can improve performance. However, this is a trade-off, and regardless of the approach, directly training the base model often yields the best results. 51 | 52 | 53 | ### If you find this work useful, please consider citing: 54 | ``` 55 | @article{peng2024controlnext, 56 | title={ControlNeXt: Powerful and Efficient Control for Image and Video Generation}, 57 | author={Peng, Bohao and Wang, Jian and Zhang, Yuechen and Li, Wenbo and Yang, Ming-Chang and Jia, Jiaya}, 58 | journal={arXiv preprint arXiv:2408.06070}, 59 | year={2024} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | moviepy 3 | opencv-python 4 | pillow 5 | numpy 6 | transformers 7 | diffusers 8 | safetensors 9 | peft 10 | decord 11 | einops --------------------------------------------------------------------------------