├── src
│   └── animatediff
│       ├── utils
│       │   ├── __init__.py
│       │   ├── wild_card.py
│       │   ├── mask_rembg.py
│       │   ├── mask_animseg.py
│       │   ├── civitai2config.py
│       │   ├── device.py
│       │   ├── pipeline.py
│       │   ├── huggingface.py
│       │   ├── convert_lora_safetensor_to_diffusers.py
│       │   ├── tagger.py
│       │   └── composite.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── clip.py
│       │   └── resnet.py
│       ├── repo
│       │   └── .gitignore
│       ├── rife
│       │   ├── __init__.py
│       │   ├── ncnn.py
│       │   ├── rife.py
│       │   └── ffmpeg.py
│       ├── __main__.py
│       ├── ip_adapter
│       │   ├── __init__.py
│       │   └── resampler.py
│       ├── softmax_splatting
│       │   ├── correlation
│       │   │   └── README.md
│       │   └── README.md
│       ├── pipelines
│       │   ├── __init__.py
│       │   ├── context.py
│       │   ├── ti.py
│       │   └── lora.py
│       ├── dwpose
│       │   ├── wholebody.py
│       │   ├── __init__.py
│       │   └── onnxdet.py
│       ├── __init__.py
│       ├── schedulers.py
│       └── settings.py
├── setup.py
├── config
│   ├── prompts
│   │   ├── ignore_tokens.txt
│   │   ├── to_8fps_Frames.bat
│   │   ├── concat_2horizontal.bat
│   │   ├── copy_png.bat
│   │   ├── 01-ToonYou.json
│   │   ├── 04-MajicMix.json
│   │   ├── 08-GhibliBackground.json
│   │   ├── 03-RcnzCartoon.json
│   │   ├── 06-Tusun.json
│   │   ├── 05-RealisticVision.json
│   │   ├── 07-FilmVelvia.json
│   │   ├── 02-Lyriel.json
│   │   ├── prompt_travel_multi_controlnet.json
│   │   └── img2img_sample.json
│   ├── inference
│   │   ├── motion_sdxl.json
│   │   ├── default.json
│   │   ├── motion_v2.json
│   │   ├── sd15-unet3d.json
│   │   └── sd15-unet.json
│   └── GroundingDINO
│       ├── GroundingDINO_SwinB_cfg.py
│       └── GroundingDINO_SwinT_OGC.py
├── MANIFEST.in
├── scripts
│   ├── download
│   │   ├── 11-ToonYou.sh
│   │   ├── 12-Lyriel.sh
│   │   ├── 14-MajicMix.sh
│   │   ├── 13-RcnzCartoon.sh
│   │   ├── 15-RealisticVision.sh
│   │   ├── 16-Tusun.sh
│   │   ├── 17-FilmVelvia.sh
│   │   ├── 18-GhibliBackground.sh
│   │   ├── 03-BaseSD.py
│   │   ├── 01-Motion-Modules.sh
│   │   ├── sd-models.aria2
│   │   └── 02-All-SD-Models.sh
│   └── test_persistent.py
├── data
│   └── models
│       ├── WD14tagger
│       │   └── model.onnx
│       └── README.md
├── .editorconfig
├── pyproject.toml
├── .pre-commit-config.yaml
├── setup.cfg
├── .vscode
│   └── settings.json
├── test.py
├── app.py
├── requirements.txt
├── README.md
├── .gitignore
└── example.md
/src/animatediff/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/animatediff/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/animatediff/repo/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup()
4 |
--------------------------------------------------------------------------------
/config/prompts/ignore_tokens.txt:
--------------------------------------------------------------------------------
1 | motion_blur
2 | blurry
3 | realistic
4 | depth_of_field
5 |
--------------------------------------------------------------------------------
/config/prompts/to_8fps_Frames.bat:
--------------------------------------------------------------------------------
1 | ffmpeg -i %1 -start_number 0 -vf "scale=512:768,fps=8" %%04d.png
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # setuptools_scm will grab all tracked files, minus these exclusions
2 | prune .vscode
3 |
--------------------------------------------------------------------------------
/src/animatediff/rife/__init__.py:
--------------------------------------------------------------------------------
1 | from .rife import app
2 |
3 | __all__ = [
4 | "app",
5 | ]
6 |
--------------------------------------------------------------------------------
/src/animatediff/__main__.py:
--------------------------------------------------------------------------------
1 | from animatediff.cli import cli
2 |
3 | if __name__ == "__main__":
4 | cli()
5 |
--------------------------------------------------------------------------------
/config/prompts/concat_2horizontal.bat:
--------------------------------------------------------------------------------
1 | ffmpeg -i %1 -i %2 -filter_complex "[0:v][1:v]hstack=inputs=2[v]" -map "[v]" -crf 15 2horizontal.mp4
--------------------------------------------------------------------------------
/scripts/download/11-ToonYou.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/78775 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/scripts/download/12-Lyriel.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/72396 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/scripts/download/14-MajicMix.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/79068 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/data/models/WD14tagger/model.onnx:
--------------------------------------------------------------------------------
1 | ../../../../../.cache/huggingface/hub/models--SmilingWolf--wd-v1-4-moat-tagger-v2/blobs/b8cef913be4c9e8d93f9f903e74271416502ce0b4b04df0ff1e2f00df488aa03
--------------------------------------------------------------------------------
/scripts/download/13-RcnzCartoon.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/71009 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/scripts/download/15-RealisticVision.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/29460 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 |
--------------------------------------------------------------------------------
/data/models/README.md:
--------------------------------------------------------------------------------
1 | ## Folder that contains the weights
2 |
3 | Put the base model weights in a new 'huggingface' folder and the motion module weights in a new 'motion-module' folder.
4 |
5 |
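6 | A minimal sketch of that layout, assuming commands are run from the repository root (the `data/models` prefix is how the rest of this repo resolves paths; this README only names the two folders):
7 |
8 | ```sh
9 | # create the weight folders described above
10 | mkdir -p data/models/huggingface data/models/motion-module
11 | ```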
--------------------------------------------------------------------------------
/config/prompts/copy_png.bat:
--------------------------------------------------------------------------------
1 |
2 | setlocal enableDelayedExpansion
3 | FOR /l %%N in (1,1,%~n1) do (
4 | set "n=00000%%N"
5 | set "TEST=!n:~-5!
6 | echo !TEST!
7 | copy /y %1 !TEST!.png
8 | )
9 |
10 | ren %1 00000.png
11 |
12 |
--------------------------------------------------------------------------------
/scripts/download/16-Tusun.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/97261 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 | wget https://civitai.com/api/download/models/50705 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
4 |
--------------------------------------------------------------------------------
/scripts/download/17-FilmVelvia.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/90115 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 | wget https://civitai.com/api/download/models/92475 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
4 |
--------------------------------------------------------------------------------
/scripts/download/18-GhibliBackground.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://civitai.com/api/download/models/102828 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
3 | wget https://civitai.com/api/download/models/57618 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
4 |
--------------------------------------------------------------------------------
/src/animatediff/ip_adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from .ip_adapter import (IPAdapter, IPAdapterFull, IPAdapterPlus,
2 | IPAdapterPlusXL, IPAdapterXL)
3 |
4 | __all__ = [
5 | "IPAdapter",
6 | "IPAdapterPlus",
7 | "IPAdapterPlusXL",
8 | "IPAdapterXL",
9 | "IPAdapterFull",
10 | ]
11 |
--------------------------------------------------------------------------------
/src/animatediff/softmax_splatting/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately.
--------------------------------------------------------------------------------
/src/animatediff/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .animation import AnimationPipeline, AnimationPipelineOutput
2 | from .context import get_context_scheduler, get_total_steps, ordered_halving, uniform
3 | from .ti import get_text_embeddings, load_text_embeddings
4 |
5 | __all__ = [
6 | "AnimationPipeline",
7 | "AnimationPipelineOutput",
8 | "get_context_scheduler",
9 | "get_total_steps",
10 | "ordered_halving",
11 | "uniform",
12 | "get_text_embeddings",
13 | "load_text_embeddings",
14 | ]
15 |
--------------------------------------------------------------------------------
/scripts/download/03-BaseSD.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from diffusers.pipelines import StableDiffusionPipeline
3 |
4 | from animatediff import get_dir
5 |
6 | out_dir = get_dir("data/models/huggingface/stable-diffusion-v1-5")
7 |
8 | pipeline = StableDiffusionPipeline.from_pretrained(
9 | "runwayml/stable-diffusion-v1-5",
10 | use_safetensors=True,
11 |     safety_checker=None,
12 |     requires_safety_checker=False,
13 | )
14 | pipeline.save_pretrained(
15 |     save_directory=str(out_dir),
16 |     safe_serialization=True,
17 | )
18 |
--------------------------------------------------------------------------------
/scripts/download/01-Motion-Modules.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | echo "Attempting download of Motion Module models from Google Drive."
4 | echo "If this fails, please download them manually from the links in the error messages/README."
5 |
6 | gdown 1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq -O models/motion-module/ || true
7 | gdown 1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu -O models/motion-module/ || true
8 |
9 | echo "Motion module download script complete."
10 | echo "If you see errors above, please download the models manually from the links in the error messages/README."
11 | exit 0
12 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [*.{json,jsonc}]
18 | indent_style = space
19 | indent_size = 2
20 |
21 | [.vscode/*.{json,jsonc}]
22 | indent_style = space
23 | indent_size = 4
24 |
25 | [*.{yml,yaml,toml}]
26 | indent_style = space
27 | indent_size = 2
28 |
29 | [*.md]
30 | trim_trailing_whitespace = false
31 |
32 | [Makefile]
33 | indent_style = tab
34 | indent_size = 8
35 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "setuptools.build_meta"
3 | requires = ["setuptools>=46.4.0", "wheel", "setuptools_scm[toml]>=6.2"]
4 |
5 | [tool.setuptools_scm]
6 | write_to = "src/animatediff/_version.py"
7 |
8 | [tool.black]
9 | line-length = 110
10 | target-version = ['py310']
11 | ignore = ['F841', 'F401', 'E501']
12 | preview = true
13 |
14 | [tool.ruff]
15 | line-length = 110
16 | target-version = 'py310'
17 | ignore = ['F841', 'F401', 'E501']
18 |
19 | [tool.ruff.isort]
20 | combine-as-imports = true
21 | force-wrap-aliases = true
22 | known-local-folder = ["src"]
23 | known-first-party = ["animatediff"]
24 |
25 | [tool.pyright]
26 | include = ['src/**']
27 | exclude = ['/usr/lib/**']
28 |
--------------------------------------------------------------------------------
/config/inference/motion_sdxl.json:
--------------------------------------------------------------------------------
1 | {
2 | "unet_additional_kwargs": {
3 | "unet_use_temporal_attention": false,
4 | "use_motion_module": true,
5 | "motion_module_resolutions": [1, 2, 4, 8],
6 | "motion_module_mid_block": false,
7 | "motion_module_type": "Vanilla",
8 | "motion_module_kwargs": {
9 | "num_attention_heads": 8,
10 | "num_transformer_block": 1,
11 | "attention_block_types": ["Temporal_Self", "Temporal_Self"],
12 | "temporal_position_encoding": true,
13 | "temporal_position_encoding_max_len": 32,
14 | "temporal_attention_dim_div": 1
15 | }
16 | },
17 | "noise_scheduler_kwargs": {
18 | "num_train_timesteps": 1000,
19 | "beta_start": 0.00085,
20 | "beta_end": 0.020,
21 | "beta_schedule": "scaled_linear"
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | ci:
3 | autofix_prs: true
4 | autoupdate_branch: "main"
5 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
6 | autoupdate_schedule: weekly
7 |
8 | repos:
9 | - repo: https://github.com/astral-sh/ruff-pre-commit
10 | rev: "v0.0.281"
11 | hooks:
12 | - id: ruff
13 | args: ["--fix", "--exit-non-zero-on-fix"]
14 |
15 | - repo: https://github.com/psf/black
16 | rev: 23.7.0
17 | hooks:
18 | - id: black
19 | args: ["--line-length=110"]
20 |
21 | - repo: https://github.com/pre-commit/pre-commit-hooks
22 | rev: v4.4.0
23 | hooks:
24 | - id: trailing-whitespace
25 | args: [--markdown-linebreak-ext=md]
26 | - id: end-of-file-fixer
27 | - id: check-yaml
28 | - id: check-added-large-files
29 |
--------------------------------------------------------------------------------
/config/inference/default.json:
--------------------------------------------------------------------------------
1 | {
2 | "unet_additional_kwargs": {
3 | "unet_use_cross_frame_attention": false,
4 | "unet_use_temporal_attention": false,
5 | "use_motion_module": true,
6 | "motion_module_resolutions": [1, 2, 4, 8],
7 | "motion_module_mid_block": false,
8 | "motion_module_decoder_only": false,
9 | "motion_module_type": "Vanilla",
10 | "motion_module_kwargs": {
11 | "num_attention_heads": 8,
12 | "num_transformer_block": 1,
13 | "attention_block_types": ["Temporal_Self", "Temporal_Self"],
14 | "temporal_position_encoding": true,
15 | "temporal_position_encoding_max_len": 24,
16 | "temporal_attention_dim_div": 1
17 | }
18 | },
19 | "noise_scheduler_kwargs": {
20 | "num_train_timesteps": 1000,
21 | "beta_start": 0.00085,
22 | "beta_end": 0.012,
23 | "beta_schedule": "linear",
24 | "steps_offset": 1,
25 | "clip_sample": false
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/config/inference/motion_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "unet_additional_kwargs": {
3 | "use_inflated_groupnorm": true,
4 | "unet_use_cross_frame_attention": false,
5 | "unet_use_temporal_attention": false,
6 | "use_motion_module": true,
7 | "motion_module_resolutions": [1, 2, 4, 8],
8 | "motion_module_mid_block": true,
9 | "motion_module_decoder_only": false,
10 | "motion_module_type": "Vanilla",
11 | "motion_module_kwargs": {
12 | "num_attention_heads": 8,
13 | "num_transformer_block": 1,
14 | "attention_block_types": ["Temporal_Self", "Temporal_Self"],
15 | "temporal_position_encoding": true,
16 | "temporal_position_encoding_max_len": 32,
17 | "temporal_attention_dim_div": 1
18 | }
19 | },
20 | "noise_scheduler_kwargs": {
21 | "num_train_timesteps": 1000,
22 | "beta_start": 0.00085,
23 | "beta_end": 0.012,
24 | "beta_schedule": "linear",
25 | "steps_offset": 1,
26 | "clip_sample": false
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/scripts/test_persistent.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 |
3 | from animatediff import get_dir
4 | from animatediff.cli import generate, logger
5 |
6 | config_dir = get_dir("config")
7 |
8 | config_path = config_dir.joinpath("prompts/test.json")
9 | width = 512
10 | height = 512
11 | length = 32
12 | context = 16
13 | stride = 4
14 |
15 | logger.warn("Running first-round generation test, this should load the full model.\n\n")
16 | out_dir = generate(
17 | config_path=config_path,
18 | width=width,
19 | height=height,
20 | length=length,
21 | context=context,
22 | stride=stride,
23 | )
24 | logger.warn(f"Generated animation to {out_dir}")
25 |
26 | logger.warn("\n\nRunning second-round generation test, this should reuse the already loaded model.\n\n")
27 | out_dir = generate(
28 | config_path=config_path,
29 | width=width,
30 | height=height,
31 | length=length,
32 | context=context,
33 | stride=stride,
34 | )
35 | logger.warn(f"Generated animation to {out_dir}")
36 |
37 | logger.error("If the second round didn't talk about reloading the model, it worked! yay!")
38 |
--------------------------------------------------------------------------------
/scripts/download/sd-models.aria2:
--------------------------------------------------------------------------------
1 | https://civitai.com/api/download/models/78775
2 | out=models/sd/toonyou_beta3.safetensors
3 | https://civitai.com/api/download/models/72396
4 | out=models/sd/lyriel_v16.safetensors
5 | https://civitai.com/api/download/models/71009
6 | out=models/sd/rcnzCartoon3d_v10.safetensors
7 | https://civitai.com/api/download/models/79068
8 | out=models/sd/majicmixRealistic_v5Preview.safetensors
9 | https://civitai.com/api/download/models/29460
10 | out=models/sd/realisticVisionV40_v20Novae.safetensors
11 | https://civitai.com/api/download/models/97261
12 | out=models/sd/TUSUN.safetensors
13 | https://civitai.com/api/download/models/50705
14 | out=models/sd/leosamsMoonfilm_reality20.safetensors
15 | https://civitai.com/api/download/models/90115
16 | out=models/sd/FilmVelvia2.safetensors
17 | https://civitai.com/api/download/models/92475
18 | out=models/sd/leosamsMoonfilm_filmGrain10.safetensors
19 | https://civitai.com/api/download/models/102828
20 | out=models/sd/Pyramid\ lora_Ghibli_n3.safetensors
21 | https://civitai.com/api/download/models/57618
22 | out=models/sd/CounterfeitV30_v30.safetensors
23 |
--------------------------------------------------------------------------------
/config/GroundingDINO/GroundingDINO_SwinB_cfg.py:
--------------------------------------------------------------------------------
1 | batch_size = 1
2 | modelname = "groundingdino"
3 | backbone = "swin_B_384_22k"
4 | position_embedding = "sine"
5 | pe_temperatureH = 20
6 | pe_temperatureW = 20
7 | return_interm_indices = [1, 2, 3]
8 | backbone_freeze_keywords = None
9 | enc_layers = 6
10 | dec_layers = 6
11 | pre_norm = False
12 | dim_feedforward = 2048
13 | hidden_dim = 256
14 | dropout = 0.0
15 | nheads = 8
16 | num_queries = 900
17 | query_dim = 4
18 | num_patterns = 0
19 | num_feature_levels = 4
20 | enc_n_points = 4
21 | dec_n_points = 4
22 | two_stage_type = "standard"
23 | two_stage_bbox_embed_share = False
24 | two_stage_class_embed_share = False
25 | transformer_activation = "relu"
26 | dec_pred_bbox_embed_share = True
27 | dn_box_noise_scale = 1.0
28 | dn_label_noise_ratio = 0.5
29 | dn_label_coef = 1.0
30 | dn_bbox_coef = 1.0
31 | embed_init_tgt = True
32 | dn_labelbook_size = 2000
33 | max_text_len = 256
34 | text_encoder_type = "bert-base-uncased"
35 | use_text_enhancer = True
36 | use_fusion_layer = True
37 | use_checkpoint = True
38 | use_transformer_ckpt = True
39 | use_text_cross_attention = True
40 | text_dropout = 0.0
41 | fusion_dropout = 0.0
42 | fusion_droppath = 0.1
43 | sub_sentence_present = True
44 |
--------------------------------------------------------------------------------
/config/GroundingDINO/GroundingDINO_SwinT_OGC.py:
--------------------------------------------------------------------------------
1 | batch_size = 1
2 | modelname = "groundingdino"
3 | backbone = "swin_T_224_1k"
4 | position_embedding = "sine"
5 | pe_temperatureH = 20
6 | pe_temperatureW = 20
7 | return_interm_indices = [1, 2, 3]
8 | backbone_freeze_keywords = None
9 | enc_layers = 6
10 | dec_layers = 6
11 | pre_norm = False
12 | dim_feedforward = 2048
13 | hidden_dim = 256
14 | dropout = 0.0
15 | nheads = 8
16 | num_queries = 900
17 | query_dim = 4
18 | num_patterns = 0
19 | num_feature_levels = 4
20 | enc_n_points = 4
21 | dec_n_points = 4
22 | two_stage_type = "standard"
23 | two_stage_bbox_embed_share = False
24 | two_stage_class_embed_share = False
25 | transformer_activation = "relu"
26 | dec_pred_bbox_embed_share = True
27 | dn_box_noise_scale = 1.0
28 | dn_label_noise_ratio = 0.5
29 | dn_label_coef = 1.0
30 | dn_bbox_coef = 1.0
31 | embed_init_tgt = True
32 | dn_labelbook_size = 2000
33 | max_text_len = 256
34 | text_encoder_type = "bert-base-uncased"
35 | use_text_enhancer = True
36 | use_fusion_layer = True
37 | use_checkpoint = True
38 | use_transformer_ckpt = True
39 | use_text_cross_attention = True
40 | text_dropout = 0.0
41 | fusion_dropout = 0.0
42 | fusion_droppath = 0.1
43 | sub_sentence_present = True
44 |
--------------------------------------------------------------------------------
/config/prompts/01-ToonYou.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "ToonYou",
3 | "base": "",
4 | "path": "models/sd/toonyou_beta3.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "compile": false,
7 | "seed": [
8 | 10788741199826055000, 6520604954829637000, 6519455744612556000,
9 | 16372571278361864000
10 | ],
11 | "scheduler": "k_dpmpp",
12 | "steps": 30,
13 | "guidance_scale": 8.5,
14 | "clip_skip": 2,
15 | "prompt": [
16 | "1girl, solo, best quality, masterpiece, looking at viewer, purple hair, orange hair, gradient hair, blurry background, upper body, dress, flower print, spaghetti strap, bare shoulders",
17 | "1girl, solo, masterpiece, best quality, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,",
18 | "1girl, solo, best quality, masterpiece, looking at viewer, purple hair, orange hair, gradient hair, blurry background, upper body, dress, flower print, spaghetti strap, bare shoulders",
19 | "1girl, solo, best quality, masterpiece, cloudy sky, dandelion, contrapposto, alternate hairstyle"
20 | ],
21 | "n_prompt": [
22 | "worst quality, low quality, cropped, lowres, text, jpeg artifacts, multiple view"
23 | ]
24 | }
25 |
--------------------------------------------------------------------------------
/src/animatediff/utils/wild_card.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import re
5 |
6 | wild_card_regex = r'(\A|\W)__([\w-]+)__(\W|\Z)'
7 |
8 |
9 | def create_wild_card_map(wild_card_dir):
10 | result = {}
11 | if os.path.isdir(wild_card_dir):
12 | txt_list = glob.glob( os.path.join(wild_card_dir ,"**/*.txt"), recursive=True)
13 | for txt in txt_list:
14 | basename_without_ext = os.path.splitext(os.path.basename(txt))[0]
15 | with open(txt, encoding='utf-8') as f:
16 | try:
17 | result[basename_without_ext] = [s.rstrip() for s in f.readlines()]
18 | except Exception as e:
19 | print(e)
20 |                     print("cannot read", txt)
21 | return result
22 |
23 | def replace_wild_card_token(match_obj, wild_card_map):
24 | m1 = match_obj.group(1)
25 | m3 = match_obj.group(3)
26 |
27 | dict_name = match_obj.group(2)
28 |
29 | if dict_name in wild_card_map:
30 | token_list = wild_card_map[dict_name]
31 | token = token_list[random.randint(0,len(token_list)-1)]
32 | return m1+token+m3
33 | else:
34 | return match_obj.group(0)
35 |
36 | def replace_wild_card(prompt, wild_card_dir):
37 | wild_card_map = create_wild_card_map(wild_card_dir)
38 | prompt = re.sub(wild_card_regex, lambda x: replace_wild_card_token(x, wild_card_map ), prompt)
39 | return prompt
40 |
--------------------------------------------------------------------------------
/config/prompts/04-MajicMix.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "MajicMix",
3 | "base": "",
4 | "path": "models/sd/majicmixRealistic_v5Preview.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 1572448948722921000, 1099474677988590700, 6488833139725636000,
8 | 18339859844376519000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "prompt": [
14 | "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic",
15 | "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting",
16 | "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below",
17 | "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic"
18 | ],
19 | "n_prompt": [
20 | "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles",
21 | "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome",
22 | "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome",
23 | "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people"
24 | ]
25 | }
26 |
--------------------------------------------------------------------------------
/config/prompts/08-GhibliBackground.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "GhibliBackground",
3 | "base": "models/sd/CounterfeitV30_25.safetensors",
4 | "path": "models/sd/lora_Ghibli_n3.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 8775748474469046000, 5893874876080607000, 11911465742147697000,
8 | 12437784838692000000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "lora_alpha": 1,
14 | "prompt": [
15 | "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall",
16 | "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter",
17 | ",mysterious sea area, fantasy,build,concept",
18 | "Tomb Raider,Scenography,Old building"
19 | ],
20 | "n_prompt": [
21 | "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality"
22 | ]
23 | }
24 |
--------------------------------------------------------------------------------
/scripts/download/02-All-SD-Models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -euo pipefail
3 |
4 | repo_dir=$(git rev-parse --show-toplevel 2>/dev/null || true)
5 | if [[ ! -d "${repo_dir}" ]]; then
6 | echo "Could not find the repo root. Checking for ./data/models/sd"
7 | repo_dir="."
8 | fi
9 |
10 | models_dir=$(realpath "${repo_dir}/data/models/sd" 2>/dev/null || true)
11 | if [[ ! -d "${models_dir}" ]]; then
12 | echo "Could not find repo root or models directory."
13 | echo "Either create ./data/models/sd or run this script from a checked-out git repo."
14 | exit 1
15 | fi
16 |
17 | model_urls=(
18 | https://civitai.com/api/download/models/78775 # ToonYou
19 | https://civitai.com/api/download/models/72396 # Lyriel
20 | https://civitai.com/api/download/models/71009 # RcnzCartoon
21 | https://civitai.com/api/download/models/79068 # MajicMix
22 | https://civitai.com/api/download/models/29460 # RealisticVision
23 | https://civitai.com/api/download/models/97261 # Tusun (1/2)
24 | https://civitai.com/api/download/models/50705 # Tusun (2/2)
25 | https://civitai.com/api/download/models/90115 # FilmVelvia (1/2)
26 | https://civitai.com/api/download/models/92475 # FilmVelvia (2/2)
27 | https://civitai.com/api/download/models/102828 # GhibliBackground (1/2)
28 | https://civitai.com/api/download/models/57618 # GhibliBackground (2/2)
29 | )
30 |
31 | echo "Downloading model files to ${models_dir}..."
32 |
33 | # Create the models directory if it doesn't exist
34 | mkdir -p "${models_dir}"
35 |
36 | # Download the models
37 | for url in "${model_urls[@]}"; do
38 | curl -JLO --output-dir "${models_dir}" "${url}" || true
39 | done
40 |
--------------------------------------------------------------------------------
/src/animatediff/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import cv2
3 | import numpy as np
4 | import onnxruntime as ort
5 |
6 | from .onnxdet import inference_detector
7 | from .onnxpose import inference_pose
8 |
9 |
10 | class Wholebody:
11 | def __init__(self, device='cuda:0'):
12 | providers = ['CPUExecutionProvider'
13 | ] if device == 'cpu' else ['CUDAExecutionProvider']
14 | onnx_det = 'data/models/DWPose/yolox_l.onnx'
15 | onnx_pose = 'data/models/DWPose/dw-ll_ucoco_384.onnx'
16 |
17 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
18 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
19 |
20 | def __call__(self, oriImg):
21 | det_result = inference_detector(self.session_det, oriImg)
22 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
23 |
24 | keypoints_info = np.concatenate(
25 | (keypoints, scores[..., None]), axis=-1)
26 | # compute neck joint
27 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
28 | # neck score when visualizing pred
29 | neck[:, 2:4] = np.logical_and(
30 | keypoints_info[:, 5, 2:4] > 0.3,
31 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
32 | new_keypoints_info = np.insert(
33 | keypoints_info, 17, neck, axis=1)
34 | mmpose_idx = [
35 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
36 | ]
37 | openpose_idx = [
38 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
39 | ]
40 | new_keypoints_info[:, openpose_idx] = \
41 | new_keypoints_info[:, mmpose_idx]
42 | keypoints_info = new_keypoints_info
43 |
44 | keypoints, scores = keypoints_info[
45 | ..., :2], keypoints_info[..., 2]
46 |
47 | return keypoints, scores
48 |
49 |
50 |
--------------------------------------------------------------------------------
/src/animatediff/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from ._version import (
3 | version as __version__,
4 | version_tuple,
5 | )
6 | except ImportError:
7 | __version__ = "unknown (no version information available)"
8 | version_tuple = (0, 0, "unknown", "noinfo")
9 |
10 | from functools import lru_cache
11 | from os import getenv
12 | from pathlib import Path
13 | from warnings import filterwarnings
14 |
15 | from rich.console import Console
16 | from tqdm import TqdmExperimentalWarning
17 |
18 | PACKAGE = __package__.replace("_", "-")
19 | PACKAGE_ROOT = Path(__file__).parent.parent
20 |
21 | HF_HOME = Path(getenv("HF_HOME", Path.home() / ".cache" / "huggingface"))
22 | HF_HUB_CACHE = Path(getenv("HUGGINGFACE_HUB_CACHE", HF_HOME.joinpath("hub")))
23 |
24 | HF_LIB_NAME = "animatediff-cli"
25 | HF_LIB_VER = __version__
26 | HF_MODULE_REPO = "neggles/animatediff-modules"
27 |
28 | console = Console(highlight=True)
29 | err_console = Console(stderr=True)
30 |
31 | # shhh torch, don't worry about it it's fine
32 | filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
33 | # you too tqdm
34 | filterwarnings("ignore", category=TqdmExperimentalWarning)
35 |
36 |
37 | @lru_cache(maxsize=4)
38 | def get_dir(dirname: str = "data") -> Path:
39 | if PACKAGE_ROOT.name == "src":
40 | # we're installed in editable mode from within the repo
41 | dirpath = PACKAGE_ROOT.parent.joinpath(dirname)
42 | else:
43 | # we're installed normally, so we just use the current working directory
44 | dirpath = Path.cwd().joinpath(dirname)
45 | dirpath.mkdir(parents=True, exist_ok=True)
46 | return dirpath.absolute()
47 |
48 |
49 | __all__ = [
50 | "__version__",
51 | "version_tuple",
52 | "PACKAGE",
53 | "PACKAGE_ROOT",
54 | "HF_HOME",
55 | "HF_HUB_CACHE",
56 | "console",
57 | "err_console",
58 | "get_dir",
59 | "models",
60 | "pipelines",
61 | "rife",
62 | "utils",
63 | "cli",
64 | "generate",
65 | "schedulers",
66 | "settings",
67 | ]
68 |
--------------------------------------------------------------------------------
/config/inference/sd15-unet3d.json:
--------------------------------------------------------------------------------
1 | {
2 | "sample_size": 64,
3 | "in_channels": 4,
4 | "out_channels": 4,
5 | "center_input_sample": false,
6 | "flip_sin_to_cos": true,
7 | "freq_shift": 0,
8 | "down_block_types": [
9 | "CrossAttnDownBlock3D",
10 | "CrossAttnDownBlock3D",
11 | "CrossAttnDownBlock3D",
12 | "DownBlock3D"
13 | ],
14 | "mid_block_type": "UNetMidBlock3DCrossAttn",
15 | "up_block_types": [
16 | "UpBlock3D",
17 | "CrossAttnUpBlock3D",
18 | "CrossAttnUpBlock3D",
19 | "CrossAttnUpBlock3D"
20 | ],
21 | "only_cross_attention": false,
22 | "block_out_channels": [320, 640, 1280, 1280],
23 | "layers_per_block": 2,
24 | "downsample_padding": 1,
25 | "mid_block_scale_factor": 1,
26 | "act_fn": "silu",
27 | "norm_num_groups": 32,
28 | "norm_eps": 1e-5,
29 | "cross_attention_dim": 768,
30 | "attention_head_dim": 8,
31 | "dual_cross_attention": false,
32 | "use_linear_projection": false,
33 | "class_embed_type": null,
34 | "num_class_embeds": null,
35 | "upcast_attention": false,
36 | "resnet_time_scale_shift": "default",
37 | "use_motion_module": true,
38 | "motion_module_resolutions": [1, 2, 4, 8],
39 | "motion_module_mid_block": false,
40 | "motion_module_decoder_only": false,
41 | "motion_module_type": "Vanilla",
42 | "motion_module_kwargs": {
43 | "num_attention_heads": 8,
44 | "num_transformer_block": 1,
45 | "attention_block_types": ["Temporal_Self", "Temporal_Self"],
46 | "temporal_position_encoding": true,
47 | "temporal_position_encoding_max_len": 24,
48 | "temporal_attention_dim_div": 1
49 | },
50 | "unet_use_cross_frame_attention": false,
51 | "unet_use_temporal_attention": false,
52 | "_use_default_values": [
53 | "use_linear_projection",
54 | "mid_block_type",
55 | "upcast_attention",
56 | "dual_cross_attention",
57 | "num_class_embeds",
58 | "only_cross_attention",
59 | "class_embed_type",
60 | "resnet_time_scale_shift"
61 | ],
62 | "_class_name": "UNet3DConditionModel",
63 | "_diffusers_version": "0.6.0"
64 | }
65 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = animatediff
3 | author = Andi Powers-Holmes
4 | author_email = aholmes@omnom.net
5 | maintainer = Andi Powers-Holmes
6 | maintainer_email = aholmes@omnom.net
7 | license_files = LICENSE.md
8 |
9 | [options]
10 | python_requires = >=3.10
11 | packages = find:
12 | package_dir =
13 | =src
14 | py_modules =
15 | animatediff
16 | include_package_data = True
17 | install_requires =
18 | accelerate >= 0.20.3
19 | colorama >= 0.4.3, < 0.5.0
20 | cmake >= 3.25.0
21 | diffusers == 0.23.0
22 | einops >= 0.6.1
23 | gdown >= 4.6.6
24 | ninja >= 1.11.0
25 | numpy >= 1.22.4
26 | omegaconf >= 2.3.0
27 | pillow >= 9.4.0, < 10.0.0
28 | pydantic >= 1.10.0, < 2.0.0
29 | rich >= 13.0.0, < 14.0.0
30 | safetensors >= 0.3.1
31 | sentencepiece >= 0.1.99
32 | shellingham >= 1.5.0, < 2.0.0
33 | torch >= 2.1.0, < 2.2.0
34 | torchaudio
35 | torchvision
36 | transformers >= 4.30.2, < 4.35.0
37 | typer >= 0.9.0, < 1.0.0
38 | controlnet_aux
39 | matplotlib
40 | ffmpeg-python >= 0.2.0
41 | mediapipe
42 | xformers >= 0.0.22.post7
43 | opencv-python
44 |
45 | [options.packages.find]
46 | where = src
47 |
48 | [options.package_data]
49 | * = *.txt, *.md
50 |
51 | [options.extras_require]
52 | dev =
53 | black >= 22.3.0
54 | ruff >= 0.0.234
55 | setuptools-scm >= 7.0.0
56 | pre-commit >= 3.3.0
57 | ipython
58 | rife =
59 | ffmpeg-python >= 0.2.0
60 | stylize =
61 | ffmpeg-python >= 0.2.0
62 | onnxruntime-gpu
63 | pandas
64 | opencv-python
65 | dwpose =
66 | onnxruntime-gpu
67 | stylize_mask =
68 | ffmpeg-python >= 0.2.0
69 | pandas
70 | segment-anything-hq == 0.3
71 | groundingdino-py == 0.4.0
72 | gitpython
73 | rembg[gpu]
74 | onnxruntime-gpu
75 |
76 | [options.entry_points]
77 | console_scripts =
78 | animatediff = animatediff.cli:cli
79 |
80 | [flake8]
81 | max-line-length = 110
82 | ignore =
83 | # these are annoying during development but should be enabled later
84 | F401 # module imported but unused
85 | F841 # local variable is assigned to but never used
86 | # black automatically fixes this
87 | E501 # line too long
88 | # black breaks these two rules:
89 | E203 # whitespace before :
90 | W503 # line break before binary operator
91 | extend-exclude =
92 | .venv
93 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.insertSpaces": true,
3 | "editor.tabSize": 4,
4 | "files.trimTrailingWhitespace": true,
5 | "editor.rulers": [100, 120],
6 |
7 | "files.associations": {
8 | "*.yaml": "yaml"
9 | },
10 |
11 | "files.exclude": {
12 | "**/.git": true,
13 | "**/.svn": true,
14 | "**/.hg": true,
15 | "**/CVS": true,
16 | "**/.DS_Store": true,
17 | "**/Thumbs.db": true,
18 | "**/__pycache__": true
19 | },
20 |
21 | "[python]": {
22 | "editor.wordBasedSuggestions": false,
23 | "editor.formatOnSave": true,
24 | "editor.defaultFormatter": "ms-python.black-formatter",
25 | "editor.codeActionsOnSave": {
26 | "source.organizeImports": true
27 | }
28 | },
29 | "python.analysis.include": ["./src", "./scripts", "./tests"],
30 |
31 | "python.linting.enabled": false,
32 | "python.linting.pylintEnabled": false,
33 | "python.linting.flake8Enabled": true,
34 | "python.linting.flake8Args": ["--config=${workspaceFolder}/setup.cfg"],
35 |
36 | "[json]": {
37 | "editor.tabSize": 2,
38 | "editor.detectIndentation": false,
39 | "editor.formatOnSave": true,
40 | "editor.formatOnSaveMode": "file"
41 | },
42 |
43 | "[toml]": {
44 | "editor.tabSize": 2,
45 | "editor.detectIndentation": false,
46 | "editor.formatOnSave": true,
47 | "editor.formatOnSaveMode": "file",
48 | "editor.defaultFormatter": "tamasfe.even-better-toml",
49 | "editor.rulers": [80, 100]
50 | },
51 | "evenBetterToml.formatter.columnWidth": 88,
52 |
53 | "[yaml]": {
54 | "editor.detectIndentation": false,
55 | "editor.tabSize": 2,
56 | "editor.formatOnSave": true,
57 | "editor.formatOnSaveMode": "file"
58 | },
59 | "yaml.format.bracketSpacing": true,
60 | "yaml.format.proseWrap": "preserve",
61 | "yaml.format.singleQuote": false,
62 | "yaml.format.printWidth": 110,
63 |
64 | "[markdown]": {
65 | "files.trimTrailingWhitespace": false
66 | },
67 |
68 | "css.lint.validProperties": ["dock", "content-align", "content-justify"],
69 | "[css]": {
70 | "editor.formatOnSave": true
71 | },
72 |
73 | "remote.autoForwardPorts": false,
74 | "remote.autoForwardPortsSource": "process"
75 | }
76 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | import os
4 | import asyncio
5 |
6 | async def stylize(video):
7 | command = f"animatediff stylize create-config {video}"
8 | process = await asyncio.create_subprocess_shell(
9 | command,
10 | stdout=asyncio.subprocess.PIPE,
11 | stderr=asyncio.subprocess.PIPE
12 | )
13 | stdout, stderr = await process.communicate()
14 | if process.returncode == 0:
15 | return stdout.decode()
16 | else:
17 | print(f"Error: {stderr.decode()}")
18 |
19 | async def start_video_edit(prompt_file):
20 | command = f"animatediff stylize generate {prompt_file}"
21 | process = await asyncio.create_subprocess_shell(
22 | command,
23 | stdout=asyncio.subprocess.PIPE,
24 | stderr=asyncio.subprocess.PIPE
25 | )
26 | stdout, stderr = await process.communicate()
27 | if process.returncode == 0:
28 | return stdout.decode()
29 | else:
30 | print(f"Error: {stderr.decode()}")
31 |
32 | def edit_video(video, pos_prompt):
33 | x = asyncio.run(stylize(video))
34 | x = x.split("stylize.py")
35 | config = x[18].split("config =")[-1].strip()
36 | d = x[19].split("stylize_dir = ")[-1].strip()
37 |
38 | with open(config, 'r+') as f:
39 | data = json.load(f)
40 | data['head_prompt'] = pos_prompt
41 | data["path"] = "models/huggingface/xxmix9realistic_v40.safetensors"
42 |
43 | os.remove(config)
44 | with open(config, 'w') as f:
45 | json.dump(data, f, indent=4)
46 |
47 | out = asyncio.run(start_video_edit(d))
48 | out = out.split("Stylized results are output to ")[-1]
49 | out = out.split("stylize.py")[0].strip()
50 |
51 | cwd = os.getcwd()
52 | video_dir = cwd + "/" + out
53 |
54 | video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.flv', '.wmv'}
55 | video_path = None
56 |
57 | for dirpath, dirnames, filenames in os.walk(video_dir):
58 | for filename in filenames:
59 | if os.path.splitext(filename)[1].lower() in video_extensions:
60 | video_path = os.path.join(dirpath, filename)
61 | break
62 | if video_path:
63 | break
64 |
65 | return video_path
66 |
67 | video_path = input("Enter the path to your video: ")
68 | pos_prompt = input("Enter the what you want to do with the video: ")
69 | print("The video is stored at", edit_video(video_path, pos_prompt))
--------------------------------------------------------------------------------
/src/animatediff/utils/mask_rembg.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | from pathlib import Path
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from rembg import new_session, remove
11 | from tqdm.rich import tqdm
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | def rembg_create_fg(frame_dir, output_dir, output_mask_dir, masked_area_list,
16 | bg_color=(0,255,0),
17 | mask_padding=0,
18 | ):
19 |
20 | frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
21 |
22 | if mask_padding != 0:
23 | kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8)
24 | kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
25 |
26 | session = new_session(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
27 |
28 | for i, frame in tqdm(enumerate(frame_list),total=len(frame_list), desc=f"creating mask"):
29 | frame = Path(frame)
30 | file_name = frame.name
31 |
32 | cur_frame_no = int(frame.stem)
33 |
34 | img = Image.open(frame)
35 | img_array = np.asarray(img)
36 |
37 | mask_array = remove(img_array, only_mask=True, session=session)
38 |
39 | #mask_array = mask_array[None,...]
40 |
41 | if mask_padding < 0:
42 | mask_array = cv2.erode(mask_array.astype(np.uint8),kernel,iterations = 1)
43 | elif mask_padding > 0:
44 | mask_array = cv2.dilate(mask_array.astype(np.uint8),kernel,iterations = 1)
45 |
46 | mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, kernel2)
47 | mask_array = cv2.GaussianBlur(mask_array, (7, 7), sigmaX=3, sigmaY=3, borderType=cv2.BORDER_DEFAULT)
48 |
49 | if masked_area_list[cur_frame_no] is not None:
50 | masked_area_list[cur_frame_no] = np.where(masked_area_list[cur_frame_no] > mask_array[None,...], masked_area_list[cur_frame_no], mask_array[None,...])
51 | else:
52 | masked_area_list[cur_frame_no] = mask_array[None,...]
53 |
54 | if output_mask_dir:
55 | Image.fromarray(mask_array).save( output_mask_dir / file_name )
56 |
57 | img_array = np.asarray(img).copy()
58 | if bg_color is not None:
59 | img_array[mask_array == 0] = bg_color
60 |
61 | img = Image.fromarray(img_array)
62 |
63 | img.save( output_dir / file_name )
64 |
65 | return masked_area_list
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/config/prompts/03-RcnzCartoon.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "RcnzCartoon",
3 | "base": "",
4 | "path": "models/sd/rcnzCartoon3d_v10.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 16931037867122268000, 2094308009433392000, 4292543217695451000,
8 | 15572665120852310000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "prompt": [
14 | "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded",
15 | "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face",
16 | "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes",
17 | "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
18 | ],
19 | "n_prompt": [
20 | "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
21 | "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular",
22 | "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,",
23 | "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
24 | ]
25 | }
26 |
--------------------------------------------------------------------------------
/config/prompts/06-Tusun.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tusun",
3 | "base": "models/sd/moonfilm_reality20.safetensors",
4 | "path": "models/sd/TUSUN.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 10154078483724687000, 2664393535095473700, 4231566096207623000,
8 | 1713349740448094500
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "lora_alpha": 0.6,
14 | "prompt": [
15 | "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing",
16 | "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing",
17 | "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing",
18 | "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body"
19 | ],
20 | "n_prompt": [
21 | "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative"
22 | ]
23 | }
24 |
--------------------------------------------------------------------------------
/config/prompts/05-RealisticVision.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "RealisticVision",
3 | "base": "",
4 | "path": "models/sd/realisticVisionV20_v20.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 5658137986800322000, 12099779162349365000, 10499524853910854000,
8 | 16768009035333712000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "prompt": [
14 | "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
15 | "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot",
16 | "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
17 | "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
18 | ],
19 | "n_prompt": [
20 | "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
21 | "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
22 | "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
23 | "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
24 | ]
25 | }
26 |
--------------------------------------------------------------------------------
/config/inference/sd15-unet.json:
--------------------------------------------------------------------------------
1 | {
2 | "sample_size": 64,
3 | "in_channels": 4,
4 | "out_channels": 4,
5 | "center_input_sample": false,
6 | "flip_sin_to_cos": true,
7 | "freq_shift": 0,
8 | "down_block_types": [
9 | "CrossAttnDownBlock2D",
10 | "CrossAttnDownBlock2D",
11 | "CrossAttnDownBlock2D",
12 | "DownBlock2D"
13 | ],
14 | "mid_block_type": "UNetMidBlock2DCrossAttn",
15 | "up_block_types": [
16 | "UpBlock2D",
17 | "CrossAttnUpBlock2D",
18 | "CrossAttnUpBlock2D",
19 | "CrossAttnUpBlock2D"
20 | ],
21 | "only_cross_attention": false,
22 | "block_out_channels": [320, 640, 1280, 1280],
23 | "layers_per_block": 2,
24 | "downsample_padding": 1,
25 | "mid_block_scale_factor": 1,
26 | "act_fn": "silu",
27 | "norm_num_groups": 32,
28 | "norm_eps": 1e-5,
29 | "cross_attention_dim": 768,
30 | "transformer_layers_per_block": 1,
31 | "encoder_hid_dim": null,
32 | "encoder_hid_dim_type": null,
33 | "attention_head_dim": 8,
34 | "num_attention_heads": null,
35 | "dual_cross_attention": false,
36 | "use_linear_projection": false,
37 | "class_embed_type": null,
38 | "addition_embed_type": null,
39 | "addition_time_embed_dim": null,
40 | "num_class_embeds": null,
41 | "upcast_attention": false,
42 | "resnet_time_scale_shift": "default",
43 | "resnet_skip_time_act": false,
44 | "resnet_out_scale_factor": 1.0,
45 | "time_embedding_type": "positional",
46 | "time_embedding_dim": null,
47 | "time_embedding_act_fn": null,
48 | "timestep_post_act": null,
49 | "time_cond_proj_dim": null,
50 | "conv_in_kernel": 3,
51 | "conv_out_kernel": 3,
52 | "projection_class_embeddings_input_dim": null,
53 | "class_embeddings_concat": false,
54 | "mid_block_only_cross_attention": null,
55 | "cross_attention_norm": null,
56 | "addition_embed_type_num_heads": 64,
57 | "_use_default_values": [
58 | "transformer_layers_per_block",
59 | "use_linear_projection",
60 | "num_class_embeds",
61 | "addition_embed_type",
62 | "cross_attention_norm",
63 | "conv_out_kernel",
64 | "encoder_hid_dim_type",
65 | "projection_class_embeddings_input_dim",
66 | "num_attention_heads",
67 | "only_cross_attention",
68 | "class_embed_type",
69 | "resnet_time_scale_shift",
70 | "addition_embed_type_num_heads",
71 | "timestep_post_act",
72 | "mid_block_type",
73 | "mid_block_only_cross_attention",
74 | "time_embedding_type",
75 | "addition_time_embed_dim",
76 | "time_embedding_dim",
77 | "encoder_hid_dim",
78 | "resnet_skip_time_act",
79 | "conv_in_kernel",
80 | "upcast_attention",
81 | "dual_cross_attention",
82 | "resnet_out_scale_factor",
83 | "time_cond_proj_dim",
84 | "class_embeddings_concat",
85 | "time_embedding_act_fn"
86 | ],
87 | "_class_name": "UNet2DConditionModel",
88 | "_diffusers_version": "0.6.0"
89 | }
90 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import asyncio
4 | import gradio as gr
5 |
6 | async def stylize(video):
7 | command = f"animatediff stylize create-config {video}"
8 | process = await asyncio.create_subprocess_shell(
9 | command,
10 | stdout=asyncio.subprocess.PIPE,
11 | stderr=asyncio.subprocess.PIPE
12 | )
13 | stdout, stderr = await process.communicate()
14 | if process.returncode == 0:
15 | return stdout.decode()
16 | else:
17 |         raise RuntimeError(f"stylize create-config failed: {stderr.decode()}")
18 |
19 | async def start_video_edit(prompt_file):
20 | command = f"animatediff stylize generate {prompt_file}"
21 | process = await asyncio.create_subprocess_shell(
22 | command,
23 | stdout=asyncio.subprocess.PIPE,
24 | stderr=asyncio.subprocess.PIPE
25 | )
26 | stdout, stderr = await process.communicate()
27 | if process.returncode == 0:
28 | return stdout.decode()
29 | else:
30 |         raise RuntimeError(f"stylize generate failed: {stderr.decode()}")
31 |
32 | def edit_video(video, pos_prompt):
33 | x = asyncio.run(stylize(video))
34 |     x = x.split("stylize.py")  # split the CLI log on the "stylize.py" logger tag
35 |     config = x[18].split("config =")[-1].strip()  # brittle: assumes the generated config path shows up in a fixed log chunk
36 |     d = x[19].split("stylize_dir = ")[-1].strip()  # brittle: same assumption for the stylize output directory
37 |
38 | with open(config, 'r+') as f:
39 | data = json.load(f)
40 | data['head_prompt'] = pos_prompt
41 |         data["path"] = "models/huggingface/xxmix9realistic_v40.safetensors"  # hard-coded base model; replace with the checkpoint you downloaded
42 |
43 | os.remove(config)
44 | with open(config, 'w') as f:
45 | json.dump(data, f, indent=4)
46 |
47 | out = asyncio.run(start_video_edit(d))
48 | out = out.split("Stylized results are output to ")[-1]
49 | out = out.split("stylize.py")[0].strip()
50 |
51 | cwd = os.getcwd()
52 | video_dir = cwd + "/" + out
53 |
54 | video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.flv', '.wmv'}
55 | video_path = None
56 |
57 | for dirpath, dirnames, filenames in os.walk(video_dir):
58 | for filename in filenames:
59 | if os.path.splitext(filename)[1].lower() in video_extensions:
60 | video_path = os.path.join(dirpath, filename)
61 | break
62 | if video_path:
63 | break
64 |
65 | return video_path
66 |
67 | print("ready")
68 |
69 | with gr.Blocks() as interface:
70 | gr.Markdown("## Video Processor with Text Prompts")
71 | with gr.Row():
72 | with gr.Column():
73 | positive_prompt = gr.Textbox(label="Positive Prompt")
74 | video_input = gr.Video(label="Input Video")
75 | with gr.Column():
76 | video_output = gr.Video(label="Processed Video")
77 |
78 | process_button = gr.Button("Process Video")
79 | process_button.click(fn=edit_video,
80 | inputs=[video_input, positive_prompt],
81 | outputs=video_output
82 | )
83 |
84 | interface.launch()
85 |
--------------------------------------------------------------------------------
/src/animatediff/dwpose/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | # Openpose
3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5 | # 3rd Edited by ControlNet
6 | # 4th Edited by ControlNet (added face and correct hands)
7 |
8 | import os
9 |
10 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
11 | import cv2
12 | import numpy as np
13 | import torch
14 | from controlnet_aux.util import HWC3, resize_image
15 | from PIL import Image
16 |
17 | from . import util
18 | from .wholebody import Wholebody
19 |
20 |
21 | def draw_pose(pose, H, W):
22 | bodies = pose['bodies']
23 | faces = pose['faces']
24 | hands = pose['hands']
25 | candidate = bodies['candidate']
26 | subset = bodies['subset']
27 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
28 |
29 | canvas = util.draw_bodypose(canvas, candidate, subset)
30 |
31 | canvas = util.draw_handpose(canvas, hands)
32 |
33 | canvas = util.draw_facepose(canvas, faces)
34 |
35 | return canvas
36 |
37 |
38 | class DWposeDetector:
39 | def __init__(self):
40 | pass
41 |
42 | def to(self, device):
43 | self.pose_estimation = Wholebody(device)
44 | return self
45 |
46 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
47 | input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
48 |
49 | input_image = HWC3(input_image)
50 | input_image = resize_image(input_image, detect_resolution)
51 | H, W, C = input_image.shape
52 | with torch.no_grad():
53 | candidate, subset = self.pose_estimation(input_image)
54 | nums, keys, locs = candidate.shape
55 | candidate[..., 0] /= float(W)
56 | candidate[..., 1] /= float(H)
57 | body = candidate[:,:18].copy()
58 | body = body.reshape(nums*18, locs)
59 | score = subset[:,:18]
60 | for i in range(len(score)):
61 | for j in range(len(score[i])):
62 | if score[i][j] > 0.3:
63 | score[i][j] = int(18*i+j)
64 | else:
65 | score[i][j] = -1
66 |
67 | un_visible = subset<0.3
68 | candidate[un_visible] = -1
69 |
70 | foot = candidate[:,18:24]
71 |
72 | faces = candidate[:,24:92]
73 |
74 | hands = candidate[:,92:113]
75 | hands = np.vstack([hands, candidate[:,113:]])
76 |
77 | bodies = dict(candidate=body, subset=score)
78 | pose = dict(bodies=bodies, hands=hands, faces=faces)
79 |
80 | detected_map = draw_pose(pose, H, W)
81 | detected_map = HWC3(detected_map)
82 |
83 | img = resize_image(input_image, image_resolution)
84 | H, W, C = img.shape
85 |
86 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
87 |
88 | if output_type == "pil":
89 | detected_map = Image.fromarray(detected_map)
90 |
91 | return detected_map
92 |
--------------------------------------------------------------------------------
/config/prompts/07-FilmVelvia.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "FilmVelvia",
3 | "base": "models/sd/majicmixRealistic_v4.safetensors",
4 | "path": "models/sd/FilmVelvia2.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 358675358833372800, 3519455280971924000, 11684545350557985000,
8 | 8696855302100400000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "lora_alpha": 0.6,
14 | "prompt": [
15 | "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name",
16 | ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir",
17 | "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark",
18 | "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, "
19 | ],
20 | "n_prompt": [
21 | "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg",
22 | "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg",
23 | "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg",
24 | "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
25 | ]
26 | }
27 |
--------------------------------------------------------------------------------
/config/prompts/02-Lyriel.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Lyriel",
3 | "base": "",
4 | "path": "models/sd/lyriel_v16.safetensors",
5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt",
6 | "seed": [
7 | 10917152860782582000, 6399018107401806000, 15875751942533906000,
8 | 6653196880059937000
9 | ],
10 | "scheduler": "k_dpmpp",
11 | "steps": 25,
12 | "guidance_scale": 7.5,
13 | "prompt": [
14 | "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange",
15 | "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal",
16 | "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray",
17 | "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
18 | ],
19 | "n_prompt": [
20 | "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration",
21 | "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular",
22 | "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome",
23 | "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
24 | ]
25 | }
26 |
--------------------------------------------------------------------------------
/src/animatediff/utils/mask_animseg.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | from pathlib import Path
5 |
6 | import cv2
7 | import numpy as np
8 | import onnxruntime as rt
9 | import torch
10 | from PIL import Image
11 | from rembg import new_session, remove
12 | from tqdm.rich import tqdm
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | def animseg_create_fg(frame_dir, output_dir, output_mask_dir, masked_area_list,
17 | bg_color=(0,255,0),
18 | mask_padding=0,
19 | ):
20 |
21 | frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
22 |
23 |     if mask_padding != 0:
24 |         kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8)
25 |     kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))  # needed unconditionally for the morphological open below
26 |
27 |
28 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
29 | rmbg_model = rt.InferenceSession("data/models/anime_seg/isnetis.onnx", providers=providers)
30 |
31 | def get_mask(img, s=1024):
32 | img = (img / 255).astype(np.float32)
33 | h, w = h0, w0 = img.shape[:-1]
34 | h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
35 | ph, pw = s - h, s - w
36 | img_input = np.zeros([s, s, 3], dtype=np.float32)
37 | img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h))
38 | img_input = np.transpose(img_input, (2, 0, 1))
39 | img_input = img_input[np.newaxis, :]
40 | mask = rmbg_model.run(None, {'img': img_input})[0][0]
41 | mask = np.transpose(mask, (1, 2, 0))
42 | mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
43 | mask = cv2.resize(mask, (w0, h0))
44 | mask = (mask * 255).astype(np.uint8)
45 | return mask
46 |
47 |
48 | for i, frame in tqdm(enumerate(frame_list),total=len(frame_list), desc=f"creating mask"):
49 | frame = Path(frame)
50 | file_name = frame.name
51 |
52 | cur_frame_no = int(frame.stem)
53 |
54 | img = Image.open(frame)
55 | img_array = np.asarray(img)
56 |
57 | mask_array = get_mask(img_array)
58 |
59 | # Image.fromarray(mask_array).save( output_dir / Path("raw_" + file_name))
60 |
61 | if mask_padding < 0:
62 | mask_array = cv2.erode(mask_array.astype(np.uint8),kernel,iterations = 1)
63 | elif mask_padding > 0:
64 | mask_array = cv2.dilate(mask_array.astype(np.uint8),kernel,iterations = 1)
65 |
66 | mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, kernel2)
67 | mask_array = cv2.GaussianBlur(mask_array, (7, 7), sigmaX=3, sigmaY=3, borderType=cv2.BORDER_DEFAULT)
68 |
69 | if masked_area_list[cur_frame_no] is not None:
70 | masked_area_list[cur_frame_no] = np.where(masked_area_list[cur_frame_no] > mask_array[None,...], masked_area_list[cur_frame_no], mask_array[None,...])
71 | else:
72 | masked_area_list[cur_frame_no] = mask_array[None,...]
73 |
74 | if output_mask_dir:
75 | Image.fromarray(mask_array).save( output_mask_dir / file_name )
76 |
77 | img_array = np.asarray(img).copy()
78 | if bg_color is not None:
79 | img_array[mask_array == 0] = bg_color
80 |
81 | img = Image.fromarray(img_array)
82 |
83 | img.save( output_dir / file_name )
84 |
85 | return masked_area_list
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.30.1
3 | aiofiles==23.2.1
4 | aiohttp==3.9.5
5 | aiosignal==1.3.1
6 | altair==5.3.0
7 | analytics-python==1.4.post1
8 | -e git+https://github.com/TheNetherWatcher/Vid2Vid-using-Text-prompt.git@213aa1fc330895557d619514bc778d0985a3112a#egg=animatediff
9 | annotated-types==0.7.0
10 | antlr4-python3-runtime==4.9.3
11 | anyio==4.4.0
12 | async-timeout==4.0.3
13 | attrs==23.2.0
14 | backoff==1.10.0
15 | bcrypt==4.1.3
16 | beautifulsoup4==4.12.3
17 | blinker==1.8.2
18 | certifi==2024.2.2
19 | cffi==1.16.0
20 | charset-normalizer==3.3.2
21 | click==8.1.7
22 | cmake==3.29.3
23 | colorama==0.4.6
24 | coloredlogs==15.0.1
25 | contourpy==1.2.1
26 | controlnet-aux==0.0.9
27 | cryptography==42.0.7
28 | cycler==0.12.1
29 | diffusers==0.23.0
30 | dnspython==2.6.1
31 | einops==0.8.0
32 | email_validator==2.1.1
33 | exceptiongroup==1.2.1
34 | fastapi==0.111.0
35 | fastapi-cli==0.0.4
36 | ffmpeg-python==0.2.0
37 | ffmpy==0.3.2
38 | filelock==3.14.0
39 | Flask==3.0.3
40 | Flask-CacheBuster==1.0.0
41 | Flask-Cors==4.0.1
42 | Flask-Login==0.6.3
43 | flatbuffers==24.3.25
44 | fonttools==4.53.0
45 | frozenlist==1.4.1
46 | fsspec==2024.5.0
47 | future==1.0.0
48 | gdown==5.2.0
49 | gradio==3.0
50 | gradio_client==0.7.0
51 | h11==0.14.0
52 | httpcore==1.0.5
53 | httptools==0.6.1
54 | httpx==0.27.0
55 | huggingface-hub==0.17.3
56 | humanfriendly==10.0
57 | idna==3.7
58 | imageio==2.34.1
59 | importlib_metadata==7.1.0
60 | importlib_resources==6.4.0
61 | itsdangerous==2.2.0
62 | jax==0.4.28
63 | jaxlib==0.4.28
64 | Jinja2==3.1.4
65 | jsonschema==4.22.0
66 | jsonschema-specifications==2023.12.1
67 | kiwisolver==1.4.5
68 | lazy_loader==0.4
69 | linkify-it-py==2.0.3
70 | markdown-it-py==3.0.0
71 | markdown2==2.4.13
72 | MarkupSafe==2.1.5
73 | matplotlib==3.9.0
74 | mdit-py-plugins==0.4.1
75 | mdurl==0.1.2
76 | mediapipe==0.10.14
77 | ml-dtypes==0.4.0
78 | monotonic==1.6
79 | mpmath==1.3.0
80 | multidict==6.0.5
81 | networkx==3.3
82 | ninja==1.11.1.1
83 | numpy==1.26.4
84 | omegaconf==2.3.0
85 | opencv-contrib-python==4.9.0.80
86 | opencv-python==4.9.0.80
87 | opencv-python-headless==4.9.0.80
88 | opt-einsum==3.3.0
89 | orjson==3.10.3
90 | packaging==24.0
91 | pandas==2.2.2
92 | paramiko==3.4.0
93 | Pillow==9.5.0
94 | protobuf==4.25.3
95 | psutil==5.9.8
96 | pycparser==2.22
97 | pycryptodome==3.20.0
98 | pydantic==1.10.15
99 | pydantic_core==2.18.3
100 | pydub==0.25.1
101 | Pygments==2.18.0
102 | PyNaCl==1.5.0
103 | pyparsing==3.1.2
104 | PySocks==1.7.1
105 | python-dateutil==2.9.0.post0
106 | python-dotenv==1.0.1
107 | python-multipart==0.0.9
108 | pytz==2024.1
109 | PyYAML==6.0.1
110 | referencing==0.35.1
111 | regex==2024.5.15
112 | requests==2.32.3
113 | rich==13.7.1
114 | rpds-py==0.18.1
115 | ruff==0.4.7
116 | safetensors==0.4.3
117 | scikit-image==0.23.2
118 | scipy==1.13.1
119 | semantic-version==2.10.0
120 | sentencepiece==0.2.0
121 | shellingham==1.5.4
122 | six==1.16.0
123 | sniffio==1.3.1
124 | sounddevice==0.4.7
125 | soupsieve==2.5
126 | starlette==0.37.2
127 | sympy==1.12.1
128 | tifffile==2024.5.22
129 | timm==0.6.7
130 | tokenizers==0.14.1
131 | tomlkit==0.12.0
132 | toolz==0.12.1
133 | torch==2.1.2
134 | torchaudio==2.1.2
135 | torchvision==0.16.2
136 | tqdm==4.66.4
137 | transformers==4.34.1
138 | triton==2.1.0
139 | typer==0.12.3
140 | typing_extensions==4.12.0
141 | tzdata==2024.1
142 | uc-micro-py==1.0.3
143 | ujson==5.10.0
144 | urllib3==2.2.1
145 | uvicorn==0.30.1
146 | uvloop==0.19.0
147 | watchfiles==0.22.0
148 | websockets==11.0.3
149 | Werkzeug==3.0.3
150 | xformers==0.0.23.post1
151 | yarl==1.9.4
152 | zipp==3.19.1
--------------------------------------------------------------------------------
/src/animatediff/pipelines/context.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | import numpy as np
4 |
5 |
6 | # Whatever this is, it's utterly cursed. (It bit-reverses the 64-bit step index to produce a low-discrepancy offset in [0, 1), van der Corput style.)
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | # I have absolutely no idea how this works and I don't like that. (It yields context windows of context_size frames at several strides, with a per-step offset so the window boundaries shift between diffusion steps.)
16 | def uniform(
17 | step: int = ...,
18 | num_steps: Optional[int] = None,
19 | num_frames: int = ...,
20 | context_size: Optional[int] = None,
21 | context_stride: int = 3,
22 | context_overlap: int = 4,
23 | closed_loop: bool = True,
24 | ):
25 | if num_frames <= context_size:
26 | yield list(range(num_frames))
27 | return
28 |
29 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
30 |
31 | for context_step in 1 << np.arange(context_stride):
32 | pad = int(round(num_frames * ordered_halving(step)))
33 | for j in range(
34 | int(ordered_halving(step) * context_step) + pad,
35 | num_frames + pad + (0 if closed_loop else -context_overlap),
36 | (context_size * context_step - context_overlap),
37 | ):
38 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
39 |
40 |
41 | def shuffle(
42 | step: int = ...,
43 | num_steps: Optional[int] = None,
44 | num_frames: int = ...,
45 | context_size: Optional[int] = None,
46 | context_stride: int = 3,
47 | context_overlap: int = 4,
48 | closed_loop: bool = True,
49 | ):
50 | import random
51 | c = list(range(num_frames))
52 | c = random.sample(c, len(c))
53 |
54 | if len(c) % context_size:
55 | c += c[0:context_size - len(c) % context_size]
56 |
57 | c = random.sample(c, len(c))
58 |
59 | for i in range(0, len(c), context_size):
60 | yield c[i:i+context_size]
61 |
62 |
63 | def composite(
64 | step: int = ...,
65 | num_steps: Optional[int] = None,
66 | num_frames: int = ...,
67 | context_size: Optional[int] = None,
68 | context_stride: int = 3,
69 | context_overlap: int = 4,
70 | closed_loop: bool = True,
71 | ):
72 | if (step/num_steps) < 0.1:
73 | return shuffle(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop)
74 | else:
75 | return uniform(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop)
76 |
77 |
78 | def get_context_scheduler(name: str) -> Callable:
79 | match name:
80 | case "uniform":
81 | return uniform
82 | case "shuffle":
83 | return shuffle
84 | case "composite":
85 | return composite
86 | case _:
87 | raise ValueError(f"Unknown context_overlap policy {name}")
88 |
89 |
90 | def get_total_steps(
91 | scheduler,
92 | timesteps: list[int],
93 | num_steps: Optional[int] = None,
94 | num_frames: int = ...,
95 | context_size: Optional[int] = None,
96 | context_stride: int = 3,
97 | context_overlap: int = 4,
98 | closed_loop: bool = True,
99 | ):
100 | return sum(
101 | len(
102 | list(
103 | scheduler(
104 | i,
105 | num_steps,
106 | num_frames,
107 | context_size,
108 | context_stride,
109 | context_overlap,
110 | )
111 | )
112 | )
113 | for i in range(len(timesteps))
114 | )
115 |
--------------------------------------------------------------------------------
/src/animatediff/rife/ncnn.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | from pydantic import BaseModel, Field
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class RifeNCNNOptions(BaseModel):
11 | model_path: Path = Field(..., description="Path to RIFE model directory")
12 | input_path: Path = Field(..., description="Path to source frames directory")
13 | output_path: Optional[Path] = Field(None, description="Path to output frames directory")
14 | num_frame: Optional[int] = Field(None, description="Number of frames to generate (default N*2)")
15 | time_step: float = Field(0.5, description="Time step for interpolation (default 0.5)", gt=0.0, le=1.0)
16 | gpu_id: Optional[int | list[int]] = Field(
17 | None, description="GPU ID(s) to use (default: auto, -1 for CPU)"
18 | )
19 | load_threads: int = Field(1, description="Number of threads for frame loading", gt=0)
20 | process_threads: int = Field(2, description="Number of threads used for frame processing", gt=0)
21 | save_threads: int = Field(2, description="Number of threads for frame saving", gt=0)
22 | spatial_tta: bool = Field(False, description="Enable spatial TTA mode")
23 | temporal_tta: bool = Field(False, description="Enable temporal TTA mode")
24 | uhd: bool = Field(False, description="Enable UHD mode")
25 | verbose: bool = Field(False, description="Enable verbose logging")
26 |
27 | def get_args(self, frame_multiplier: int = 7) -> list[str]:
28 | """Generate arguments to pass to rife-ncnn-vulkan.
29 |
30 | Frame multiplier is used to calculate the number of frames to generate, if num_frame is not set.
31 | """
32 | if self.output_path is None:
33 | self.output_path = self.input_path.joinpath("out")
34 |
35 | # calc num frames
36 | if self.num_frame is None:
37 | num_src_frames = len([x for x in self.input_path.glob("*.png") if x.is_file()])
38 | logger.info(f"Found {num_src_frames} source frames, using multiplier {frame_multiplier}")
39 | num_frame = num_src_frames * frame_multiplier
40 | logger.info(f"We will generate {num_frame} frames")
41 | else:
42 | num_frame = self.num_frame
43 |
44 | # GPU ID and process threads are comma-separated lists, so we need to convert them to strings
45 | if self.gpu_id is None:
46 | gpu_id = "auto"
47 | process_threads = self.process_threads
48 | elif isinstance(self.gpu_id, list):
49 | gpu_id = ",".join([str(x) for x in self.gpu_id])
50 | process_threads = ",".join([str(self.process_threads) for _ in self.gpu_id])
51 | else:
52 | gpu_id = str(self.gpu_id)
53 | process_threads = str(self.process_threads)
54 |
55 | # Build args list
56 | args_list = [
57 | "-i",
58 | f"{self.input_path.resolve()}/",
59 | "-o",
60 | f"{self.output_path.resolve()}/",
61 | "-m",
62 | f"{self.model_path.resolve()}/",
63 | "-n",
64 | num_frame,
65 | "-s",
66 | f"{self.time_step:.5f}",
67 | "-g",
68 | gpu_id,
69 | "-j",
70 | f"{self.load_threads}:{process_threads}:{self.save_threads}",
71 | ]
72 |
73 | # Add flags if set
74 | if self.spatial_tta:
75 | args_list.append("-x")
76 | if self.temporal_tta:
77 | args_list.append("-z")
78 | if self.uhd:
79 | args_list.append("-u")
80 | if self.verbose:
81 | args_list.append("-v")
82 |
83 | # Convert all args to strings and return
84 | return [str(x) for x in args_list]
85 |
--------------------------------------------------------------------------------
/src/animatediff/schedulers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from enum import Enum
3 |
4 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
5 | DPMSolverSinglestepScheduler,
6 | EulerAncestralDiscreteScheduler,
7 | EulerDiscreteScheduler,
8 | HeunDiscreteScheduler,
9 | KDPM2AncestralDiscreteScheduler,
10 | KDPM2DiscreteScheduler, LCMScheduler,
11 | LMSDiscreteScheduler, PNDMScheduler,
12 | UniPCMultistepScheduler)
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
18 | class DiffusionScheduler(str, Enum):
19 | lcm = "lcm" # LCM
20 | ddim = "ddim" # DDIM
21 | pndm = "pndm" # PNDM
22 | heun = "heun" # Heun
23 | unipc = "unipc" # UniPC
24 | euler = "euler" # Euler
25 | euler_a = "euler_a" # Euler a
26 |
27 | lms = "lms" # LMS
28 | k_lms = "k_lms" # LMS Karras
29 |
30 | dpm_2 = "dpm_2" # DPM2
31 | k_dpm_2 = "k_dpm_2" # DPM2 Karras
32 |
33 | dpm_2_a = "dpm_2_a" # DPM2 a
34 | k_dpm_2_a = "k_dpm_2_a" # DPM2 a Karras
35 |
36 | dpmpp_2m = "dpmpp_2m" # DPM++ 2M
37 | k_dpmpp_2m = "k_dpmpp_2m" # DPM++ 2M Karras
38 |
39 | dpmpp_sde = "dpmpp_sde" # DPM++ SDE
40 | k_dpmpp_sde = "k_dpmpp_sde" # DPM++ SDE Karras
41 |
42 | dpmpp_2m_sde = "dpmpp_2m_sde" # DPM++ 2M SDE
43 | k_dpmpp_2m_sde = "k_dpmpp_2m_sde" # DPM++ 2M SDE Karras
44 |
45 |
46 | def get_scheduler(name: str, config: dict = {}):
47 |     config = dict(config)  # shallow-copy so the shared default / caller's dict is never mutated below
48 |     if is_karras := name.startswith("k_"):
49 |         # strip the k_ prefix and add the karras sigma flag to config
50 |         name = name[len("k_"):]
51 |         config["use_karras_sigmas"] = True
52 |
53 | match name:
54 | case DiffusionScheduler.lcm:
55 | sched_class = LCMScheduler
56 | case DiffusionScheduler.ddim:
57 | sched_class = DDIMScheduler
58 | case DiffusionScheduler.pndm:
59 | sched_class = PNDMScheduler
60 | case DiffusionScheduler.heun:
61 | sched_class = HeunDiscreteScheduler
62 | case DiffusionScheduler.unipc:
63 | sched_class = UniPCMultistepScheduler
64 | case DiffusionScheduler.euler:
65 | sched_class = EulerDiscreteScheduler
66 | case DiffusionScheduler.euler_a:
67 | sched_class = EulerAncestralDiscreteScheduler
68 | case DiffusionScheduler.lms:
69 | sched_class = LMSDiscreteScheduler
70 | case DiffusionScheduler.dpm_2:
71 | # Equivalent to DPM2 in K-Diffusion
72 | sched_class = KDPM2DiscreteScheduler
73 | case DiffusionScheduler.dpm_2_a:
74 | # Equivalent to `DPM2 a`` in K-Diffusion
75 | sched_class = KDPM2AncestralDiscreteScheduler
76 | case DiffusionScheduler.dpmpp_2m:
77 | # Equivalent to `DPM++ 2M` in K-Diffusion
78 | sched_class = DPMSolverMultistepScheduler
79 | config["algorithm_type"] = "dpmsolver++"
80 | config["solver_order"] = 2
81 | case DiffusionScheduler.dpmpp_sde:
82 | # Equivalent to `DPM++ SDE` in K-Diffusion
83 | sched_class = DPMSolverSinglestepScheduler
84 | case DiffusionScheduler.dpmpp_2m_sde:
85 | # Equivalent to `DPM++ 2M SDE` in K-Diffusion
86 | sched_class = DPMSolverMultistepScheduler
87 | config["algorithm_type"] = "sde-dpmsolver++"
88 | case _:
89 | raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")
90 |
91 | return sched_class.from_config(config)
92 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Automatic TikTok generator
2 |
3 | This repository contains a pipeline for video-to-video generation using text prompts. The system uses AnimateDiff for animation generation and an OpenPose ControlNet for pose estimation, and incorporates a prompt-travelling method for improved coherence between the original and generated videos. Users can interact with this pipeline through a Gradio app or a standard Python program.
4 |
5 | ## Techniques used
6 |
7 | - **AnimateDiff**: Used for generating high-quality animations from text prompts and an input image.
8 | - **OpenPose ControlNet**: Used for accurate pose estimation to guide the animation process.
9 | - **Prompt Traveling Method**: Ensures closer correspondence and coherence between the input video and the generated output.
10 | - **User Interfaces**:
11 | - **Gradio App**: An intuitive web-based interface for easy interaction.
12 | - **Python Program**: A script-based interface for users preferring command-line interaction.
13 |
14 | ### Base models
15 |
16 | - [XXMix_9realistic](https://civitai.com/models/47274): Recommended model for generating life-like video
17 | - [Mistoon_Anime](https://civitai.com/models/24149/mistoonanime): Recommended model for generating anime-style video
18 |
19 | ### Motion modules
20 |
21 | - [mm_sd_v15_v2](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt): Motion module used for generating the segments of the final video from the generated images (Recommended)
22 | - [mm_sd_v15](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15.ckpt) and [mm_sd_v14](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v14.ckpt) are other modules that can also be used.
23 |
24 | ### ControlNets
25 |
26 | - [control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_openpose.pth): ControlNet for pose estimation from the given video
27 | - Support for depth and canny ControlNets is also planned, for better generated-video quality.
28 |
29 | ### Prompt Travelling
30 |
31 | Prompt travelling is a technique for telling the model, frame by frame, what to do with the output image.
32 | For example, if the prompt body contains an entry like `30 - face: up, camera: zoomed out, right-hand: waving`, then the 30th output frame will be generated according to that prompt (see the sketch below).
33 |
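A minimal sketch of what such frame-keyed prompts look like in a config file. The `head_prompt`/`prompt_map` field names follow the convention used elsewhere in this repository; the frame numbers and prompt text are purely illustrative:

```json
{
  "head_prompt": "masterpiece, best quality, 1girl, walking in a park",
  "prompt_map": {
    "0": "face: forward, camera: wide shot",
    "30": "face: up, camera: zoomed out, right-hand: waving",
    "60": "face: smiling, camera: close up"
  }
}
```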
34 | ## Installation
35 |
36 | To set up the environment and install the necessary dependencies, follow these steps:
37 |
38 | 1. **Clone the repository:**
39 |
40 | ```bash
41 | git clone https://github.com/TheNetherWatcher/Vid2Vid-using-Text-prompt.git
42 | cd Vid2Vid-using-Text-prompt
43 | ```
44 |
45 | 2. **Create and activate a virtual environment:**
46 |
47 | ```bash
48 | python -m venv venv
49 | source venv/bin/activate # On Windows, use `venv\Scripts\activate`
50 | ```
51 |
52 | 3. **Install the required packages:**
53 |
54 | ```bash
55 | pip install -e .
56 | pip install -e .[stylize]
57 | ```
58 |
59 | ## Usage
60 |
61 | ### Model weights
62 |
63 | - Download the model weights from the above links (or others) and put them [here](./data/models/huggingface); put the downloaded motion modules [here](data/models/motion-module)
64 | - The first time you run the pipeline you might get errors like "model weights not found"; if so, go to the stylize directory and, in the most recently created folder, edit the model name in the prompt.json file (see the snippet below). Automatic handling of this is under development.
65 |
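As a minimal example (assuming the prompt.json uses the same `path` field as the configs under config/prompts; the file name below is just the default used in app.py):

```json
{
  "path": "models/huggingface/xxmix9realistic_v40.safetensors"
}
```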
66 | ### Gradio App
67 |
68 | To run the Gradio app, execute the following command:
69 |
70 | ```bash
71 | python app.py
72 | ```
73 |
74 | The Gradio app provides an interface for uploading a video and entering a text prompt as input, and it outputs the generated video.
75 |
76 | ### Commandline
77 |
78 | ```bash
79 | python test.py
80 | ```
81 |
82 | After running this, you will be prompted to enter the location of the video, a positive prompt (the changes you want to make to the video), and a negative prompt.
83 | The negative prompt is set to a default value, but you can edit it if you like.
84 |
85 | ## Upcoming Developments
86 |
87 | - LoRA support, and ControlNet support (e.g. canny, depth, edge)
88 | - Gradio app support for using different ControlNets and LoRAs
89 | - CLI options for controlling execution on different systems
90 |
91 | ## Credits
92 |
93 | - [AnimateDiff](https://github.com/guoyww/AnimateDiff)
94 | - [Prompt Travelling using AnimateDiff](https://github.com/s9roll7/animatediff-cli-prompt-travel)
95 |
--------------------------------------------------------------------------------
/src/animatediff/utils/civitai2config.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import shutil
7 | from pathlib import Path
8 |
9 | from animatediff import get_dir
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | data_dir = get_dir("data")
14 |
15 | extra_loading_regex = r'(<[^>]+?>)'
16 |
17 | def generate_config_from_civitai_info(
18 | lora_dir:Path,
19 | config_org:Path,
20 | out_dir:Path,
21 | lora_weight:float,
22 | ):
23 | lora_abs_dir = lora_dir.absolute()
24 | config_org = config_org.absolute()
25 | out_dir = out_dir.absolute()
26 |
27 | civitais = sorted(glob.glob( os.path.join(lora_abs_dir, "*.civitai.info"), recursive=False))
28 |
29 | with open(config_org, "r") as cf:
30 | org_config = json.load(cf)
31 |
32 | for civ in civitais:
33 |
34 | logger.info(f"convert {civ}")
35 |
36 | with open(civ, "r") as f:
37 | # trim .civitai.info
38 | name = os.path.splitext(os.path.splitext(os.path.basename(civ))[0])[0]
39 |
40 | output_path = out_dir.joinpath(name + ".json")
41 |
42 | if os.path.isfile(output_path):
43 | logger.info("already converted -> skip")
44 | continue
45 |
46 | if os.path.isfile( lora_abs_dir.joinpath(name + ".safetensors")):
47 | lora_path = os.path.relpath(lora_abs_dir.joinpath(name + ".safetensors"), data_dir)
48 | elif os.path.isfile( lora_abs_dir.joinpath(name + ".ckpt")):
49 | lora_path = os.path.relpath(lora_abs_dir.joinpath(name + ".ckpt"), data_dir)
50 | else:
51 | logger.info("lora file not found -> skip")
52 | continue
53 |
54 | info = json.load(f)
55 |
56 | if not info:
57 | logger.info(f"empty civitai info -> skip")
58 | continue
59 |
60 | if info["model"]["type"] not in ("LORA","lora"):
61 | logger.info(f"unsupported type {info['model']['type']} -> skip")
62 | continue
63 |
64 | new_config = org_config.copy()
65 |
66 | new_config["name"] = name
67 |
68 | new_prompt_map = {}
69 | new_n_prompt = ""
70 | new_seed = -1
71 |
72 |
73 | raw_prompt_map = {}
74 |
75 | i = 0
76 | for img_info in info["images"]:
77 | if img_info["meta"]:
78 | try:
79 | raw_prompt = img_info["meta"]["prompt"]
80 | except Exception as e:
81 | logger.info("missing prompt")
82 | continue
83 |
84 | raw_prompt_map[str(10000 + i*32)] = raw_prompt
85 |
86 | new_prompt_map[str(i*32)] = re.sub(extra_loading_regex, '', raw_prompt)
87 |
88 | if not new_n_prompt:
89 | try:
90 | new_n_prompt = img_info["meta"]["negativePrompt"]
91 | except Exception as e:
92 | new_n_prompt = ""
93 | if new_seed == -1:
94 | try:
95 | new_seed = img_info["meta"]["seed"]
96 | except Exception as e:
97 | new_seed = -1
98 |
99 | i += 1
100 |
101 | if not new_prompt_map:
102 | new_prompt_map[str(0)] = ""
103 |
104 | for k in raw_prompt_map:
105 | # comment
106 | new_prompt_map[k] = raw_prompt_map[k]
107 |
108 | new_config["prompt_map"] = new_prompt_map
109 | new_config["n_prompt"] = [new_n_prompt]
110 | new_config["seed"] = [new_seed]
111 |
112 | new_config["lora_map"] = {lora_path.replace(os.sep,'/'):lora_weight}
113 |
114 | with open( out_dir.joinpath(name + ".json"), 'w') as wf:
115 | json.dump(new_config, wf, indent=4)
116 | logger.info("converted!")
117 |
118 | preview = lora_abs_dir.joinpath(name + ".preview.png")
119 | if preview.is_file():
120 | shutil.copy(preview, out_dir.joinpath(name + ".preview.png"))
121 |
122 |
123 |
--------------------------------------------------------------------------------
/src/animatediff/utils/device.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import lru_cache
3 | from math import ceil
4 | from typing import Union
5 |
6 | import torch
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def device_info_str(device: torch.device) -> str:
12 | device_info = torch.cuda.get_device_properties(device)
13 | return (
14 | f"{device_info.name} {ceil(device_info.total_memory / 1024 ** 3)}GB, "
15 | + f"CC {device_info.major}.{device_info.minor}, {device_info.multi_processor_count} SM(s)"
16 | )
17 |
18 |
19 | @lru_cache(maxsize=4)
20 | def supports_bfloat16(device: Union[str, torch.device]) -> bool:
21 | """A non-exhaustive check for bfloat16 support on a given device.
22 | Weird that torch doesn't have a global function for this. If your device
23 | does support bfloat16 and it's not listed here, go ahead and add it.
24 | """
25 | device = torch.device(device) # make sure device is a torch.device
26 | match device.type:
27 | case "cpu":
28 | ret = False
29 | case "cuda":
30 | with device:
31 | ret = torch.cuda.is_bf16_supported()
32 | case "xla":
33 | ret = True
34 | case "mps":
35 | ret = True
36 | case _:
37 | ret = False
38 | return ret
39 |
40 |
41 | @lru_cache(maxsize=4)
42 | def maybe_bfloat16(
43 | device: Union[str, torch.device],
44 | fallback: torch.dtype = torch.float32,
45 | ) -> torch.dtype:
46 | """Returns torch.bfloat16 if available, otherwise the fallback dtype (default float32)"""
47 | device = torch.device(device) # make sure device is a torch.device
48 | return torch.bfloat16 if supports_bfloat16(device) else fallback
49 |
50 |
51 | def dtype_for_model(model: str, device: torch.device) -> torch.dtype:
52 | match model:
53 | case "unet":
54 | return torch.float32 if device.type == "cpu" else torch.float16
55 | case "tenc":
56 | return torch.float32 if device.type == "cpu" else torch.float16
57 | case "vae":
58 | return maybe_bfloat16(device, fallback=torch.float32)
59 | case unknown:
60 | raise ValueError(f"Invalid model {unknown}")
61 |
62 |
63 | def get_model_dtypes(
64 | device: Union[str, torch.device],
65 | force_half_vae: bool = False,
66 | ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
67 | device = torch.device(device) # make sure device is a torch.device
68 | unet_dtype = dtype_for_model("unet", device)
69 | tenc_dtype = dtype_for_model("tenc", device)
70 | vae_dtype = dtype_for_model("vae", device)
71 |
72 | if device.type == "cpu":
73 | logger.warn("Device explicitly set to CPU, will run everything in fp32")
74 | logger.warn("This is likely to be *incredibly* slow, but I don't tell you how to live.")
75 |
76 | if force_half_vae:
77 | if device.type == "cpu":
78 | logger.critical("Can't force VAE to fp16 mode on CPU! Exiting...")
79 | raise RuntimeError("Can't force VAE to fp16 mode on CPU!")
80 | if vae_dtype == torch.bfloat16:
81 | logger.warn("Forcing VAE to use fp16 despite bfloat16 support! This is a bad idea!")
82 | logger.warn("If you're not sure why you're doing this, you probably shouldn't be.")
83 | vae_dtype = torch.float16
84 | else:
85 | logger.warn("Forcing VAE to use fp16 instead of fp32 on CUDA! This may result in black outputs!")
86 | logger.warn("Running a VAE in fp16 can result in black images or poor output quality.")
87 | logger.warn("I don't tell you how to live, but you probably shouldn't do this.")
88 | vae_dtype = torch.float16
89 |
90 | logger.info(f"Selected data types: {unet_dtype=}, {tenc_dtype=}, {vae_dtype=}")
91 | return unet_dtype, tenc_dtype, vae_dtype
92 |
93 |
94 | def get_memory_format(device: Union[str, torch.device]) -> torch.memory_format:
95 | device = torch.device(device) # make sure device is a torch.device
96 | # if we have a cuda device
97 | if device.type == "cuda":
98 | device_info = torch.cuda.get_device_properties(device)
99 | # Volta and newer seem to like channels_last. This will probably bite me on TU11x cards.
100 | if device_info.major >= 7:
101 | ret = torch.channels_last
102 | else:
103 | ret = torch.contiguous_format
104 | elif device.type == "xpu":
105 | # Intel ARC GPUs/XPUs like channels_last
106 | ret = torch.channels_last
107 | else:
108 | # TODO: Does MPS like channels_last? do other devices?
109 | ret = torch.contiguous_format
110 | if ret == torch.channels_last:
111 | logger.info("Using channels_last memory format for UNet and VAE")
112 | return ret
--------------------------------------------------------------------------------
/src/animatediff/utils/pipeline.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 |
4 | import torch
5 | import torch._dynamo as dynamo
6 | from diffusers import (DiffusionPipeline, StableDiffusionPipeline,
7 | StableDiffusionXLPipeline)
8 | from einops._torch_specific import allow_ops_in_compiled_graph
9 |
10 | from animatediff.utils.device import get_memory_format, get_model_dtypes
11 | from animatediff.utils.model import nop_train
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def send_to_device(
17 | pipeline: DiffusionPipeline,
18 | device: torch.device,
19 | freeze: bool = True,
20 | force_half: bool = False,
21 | compile: bool = False,
22 | is_sdxl: bool = False,
23 | ) -> DiffusionPipeline:
24 | if is_sdxl:
25 | return send_to_device_sdxl(
26 | pipeline=pipeline,
27 | device=device,
28 | freeze=freeze,
29 | force_half=force_half,
30 | compile=compile,
31 | )
32 |
33 | logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"")
34 |
35 | unet_dtype, tenc_dtype, vae_dtype = get_model_dtypes(device, force_half)
36 | model_memory_format = get_memory_format(device)
37 |
38 | if hasattr(pipeline, 'controlnet'):
39 | unet_dtype = tenc_dtype = vae_dtype
40 |
41 | logger.info(f"-> Selected data types: {unet_dtype=},{tenc_dtype=},{vae_dtype=}")
42 |
43 |     if hasattr(pipeline, 'controlnet') and hasattr(pipeline.controlnet, 'nets'):
44 | for i in range(len(pipeline.controlnet.nets)):
45 | pipeline.controlnet.nets[i] = pipeline.controlnet.nets[i].to(device=device, dtype=vae_dtype, memory_format=model_memory_format)
46 | else:
47 |         if getattr(pipeline, 'controlnet', None) is not None:
48 | pipeline.controlnet = pipeline.controlnet.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)
49 |
50 | if hasattr(pipeline, 'controlnet_map'):
51 | if pipeline.controlnet_map:
52 | for c in pipeline.controlnet_map:
53 | #pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(device=device, dtype=unet_dtype, memory_format=model_memory_format)
54 | pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(dtype=unet_dtype, memory_format=model_memory_format)
55 |
56 | if hasattr(pipeline, 'lora_map'):
57 | if pipeline.lora_map:
58 | pipeline.lora_map.to(device=device, dtype=unet_dtype)
59 |
60 | if hasattr(pipeline, 'lcm'):
61 | if pipeline.lcm:
62 | pipeline.lcm.to(device=device, dtype=unet_dtype)
63 |
64 | pipeline.unet = pipeline.unet.to(device=device, dtype=unet_dtype, memory_format=model_memory_format)
65 | pipeline.text_encoder = pipeline.text_encoder.to(device=device, dtype=tenc_dtype)
66 | pipeline.vae = pipeline.vae.to(device=device, dtype=vae_dtype, memory_format=model_memory_format)
67 |
68 | # Compile model if enabled
69 | if compile:
70 | if not isinstance(pipeline.unet, dynamo.OptimizedModule):
71 | allow_ops_in_compiled_graph() # make einops behave
72 | logger.warn("Enabling model compilation with TorchDynamo, this may take a while...")
73 | logger.warn("Model compilation is experimental and may not work as expected!")
74 | pipeline.unet = torch.compile(
75 | pipeline.unet,
76 | backend="inductor",
77 | mode="reduce-overhead",
78 | )
79 | else:
80 | logger.debug("Skipping model compilation, already compiled!")
81 |
82 | return pipeline
83 |
84 |
85 | def send_to_device_sdxl(
86 | pipeline: StableDiffusionXLPipeline,
87 | device: torch.device,
88 | freeze: bool = True,
89 | force_half: bool = False,
90 | compile: bool = False,
91 | ) -> StableDiffusionXLPipeline:
92 | logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"")
93 |
94 | pipeline.unet = pipeline.unet.half()
95 | pipeline.text_encoder = pipeline.text_encoder.half()
96 | pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
97 |
98 |     if False:  # flip to True to move the whole pipeline onto the device instead of using CPU offload
99 | pipeline.to(device)
100 | else:
101 | pipeline.enable_model_cpu_offload()
102 |
103 | pipeline.enable_xformers_memory_efficient_attention()
104 | pipeline.enable_vae_slicing()
105 | pipeline.enable_vae_tiling()
106 |
107 | return pipeline
108 |
109 |
110 |
111 | def get_context_params(
112 | length: int,
113 | context: Optional[int] = None,
114 | overlap: Optional[int] = None,
115 | stride: Optional[int] = None,
116 | ):
117 | if context is None:
118 | context = min(length, 16)
119 | if overlap is None:
120 | overlap = context // 4
121 | if stride is None:
122 | stride = 0
123 | return context, overlap, stride
124 |
--------------------------------------------------------------------------------
/src/animatediff/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import cv2
3 | import numpy as np
4 | import onnxruntime
5 |
6 |
7 | def nms(boxes, scores, nms_thr):
8 | """Single class NMS implemented in Numpy."""
9 | x1 = boxes[:, 0]
10 | y1 = boxes[:, 1]
11 | x2 = boxes[:, 2]
12 | y2 = boxes[:, 3]
13 |
14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15 | order = scores.argsort()[::-1]
16 |
17 | keep = []
18 | while order.size > 0:
19 | i = order[0]
20 | keep.append(i)
21 | xx1 = np.maximum(x1[i], x1[order[1:]])
22 | yy1 = np.maximum(y1[i], y1[order[1:]])
23 | xx2 = np.minimum(x2[i], x2[order[1:]])
24 | yy2 = np.minimum(y2[i], y2[order[1:]])
25 |
26 | w = np.maximum(0.0, xx2 - xx1 + 1)
27 | h = np.maximum(0.0, yy2 - yy1 + 1)
28 | inter = w * h
29 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
30 |
31 | inds = np.where(ovr <= nms_thr)[0]
32 | order = order[inds + 1]
33 |
34 | return keep
35 |
36 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
37 | """Multiclass NMS implemented in Numpy. Class-aware version."""
38 | final_dets = []
39 | num_classes = scores.shape[1]
40 | for cls_ind in range(num_classes):
41 | cls_scores = scores[:, cls_ind]
42 | valid_score_mask = cls_scores > score_thr
43 | if valid_score_mask.sum() == 0:
44 | continue
45 | else:
46 | valid_scores = cls_scores[valid_score_mask]
47 | valid_boxes = boxes[valid_score_mask]
48 | keep = nms(valid_boxes, valid_scores, nms_thr)
49 | if len(keep) > 0:
50 | cls_inds = np.ones((len(keep), 1)) * cls_ind
51 | dets = np.concatenate(
52 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
53 | )
54 | final_dets.append(dets)
55 | if len(final_dets) == 0:
56 | return None
57 | return np.concatenate(final_dets, 0)
58 |
59 | def demo_postprocess(outputs, img_size, p6=False):
60 | grids = []
61 | expanded_strides = []
62 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
63 |
64 | hsizes = [img_size[0] // stride for stride in strides]
65 | wsizes = [img_size[1] // stride for stride in strides]
66 |
67 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
68 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
69 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
70 | grids.append(grid)
71 | shape = grid.shape[:2]
72 | expanded_strides.append(np.full((*shape, 1), stride))
73 |
74 | grids = np.concatenate(grids, 1)
75 | expanded_strides = np.concatenate(expanded_strides, 1)
76 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
77 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
78 |
79 | return outputs
80 |
81 | def preprocess(img, input_size, swap=(2, 0, 1)):
82 | if len(img.shape) == 3:
83 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
84 | else:
85 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
86 |
87 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
88 | resized_img = cv2.resize(
89 | img,
90 | (int(img.shape[1] * r), int(img.shape[0] * r)),
91 | interpolation=cv2.INTER_LINEAR,
92 | ).astype(np.uint8)
93 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
94 |
95 | padded_img = padded_img.transpose(swap)
96 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
97 | return padded_img, r
98 |
99 | def inference_detector(session, oriImg):
100 | input_shape = (640,640)
101 | img, ratio = preprocess(oriImg, input_shape)
102 |
103 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
104 | output = session.run(None, ort_inputs)
105 | predictions = demo_postprocess(output[0], input_shape)[0]
106 |
107 | boxes = predictions[:, :4]
108 | scores = predictions[:, 4:5] * predictions[:, 5:]
109 |
110 | boxes_xyxy = np.ones_like(boxes)
111 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
112 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
113 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
114 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
115 | boxes_xyxy /= ratio
116 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
117 | if dets is not None:
118 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
119 | isscore = final_scores>0.3
120 | iscat = final_cls_inds == 0
121 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
122 | final_boxes = final_boxes[isbbox]
123 | else:
124 | return []
125 |
126 | return final_boxes
127 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### macOS ###
20 | # General
21 | .DS_Store
22 | .AppleDouble
23 | .LSOverride
24 |
25 | # Icon must end with two \r
26 | Icon
27 |
28 |
29 | # Thumbnails
30 | ._*
31 |
32 | # Files that might appear in the root of a volume
33 | .DocumentRevisions-V100
34 | .fseventsd
35 | .Spotlight-V100
36 | .TemporaryItems
37 | .Trashes
38 | .VolumeIcon.icns
39 | .com.apple.timemachine.donotpresent
40 |
41 | # Directories potentially created on remote AFP share
42 | .AppleDB
43 | .AppleDesktop
44 | Network Trash Folder
45 | Temporary Items
46 | .apdisk
47 |
48 | ### Python ###
49 | # Byte-compiled / optimized / DLL files
50 | __pycache__/
51 | *.py[cod]
52 | *$py.class
53 |
54 | # C extensions
55 | *.so
56 |
57 | # Distribution / packaging
58 | .Python
59 | build/
60 | develop-eggs/
61 | dist/
62 | downloads/
63 | eggs/
64 | .eggs/
65 | lib/
66 | lib64/
67 | parts/
68 | sdist/
69 | var/
70 | wheels/
71 | share/python-wheels/
72 | *.egg-info/
73 | .installed.cfg
74 | *.egg
75 | MANIFEST
76 |
77 | # PyInstaller
78 | # Usually these files are written by a python script from a template
79 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
80 | *.manifest
81 | *.spec
82 |
83 | # Installer logs
84 | pip-log.txt
85 | pip-delete-this-directory.txt
86 |
87 | # Unit test / coverage reports
88 | htmlcov/
89 | .tox/
90 | .nox/
91 | .coverage
92 | .coverage.*
93 | .cache
94 | nosetests.xml
95 | coverage.xml
96 | *.cover
97 | *.py,cover
98 | .hypothesis/
99 | .pytest_cache/
100 | cover/
101 |
102 | # Translations
103 | *.mo
104 | *.pot
105 |
106 | # Django stuff:
107 | *.log
108 | local_settings.py
109 | db.sqlite3
110 | db.sqlite3-journal
111 |
112 | # Flask stuff:
113 | instance/
114 | .webassets-cache
115 |
116 | # Scrapy stuff:
117 | .scrapy
118 |
119 | # Sphinx documentation
120 | docs/_build/
121 |
122 | # PyBuilder
123 | .pybuilder/
124 | target/
125 |
126 | # Jupyter Notebook
127 | .ipynb_checkpoints
128 |
129 | # IPython
130 | profile_default/
131 | ipython_config.py
132 |
133 | # pyenv
134 | # For a library or package, you might want to ignore these files since the code is
135 | # intended to run in multiple environments; otherwise, check them in:
136 | # .python-version
137 |
138 | # pipenv
139 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
140 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
141 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
142 | # install all needed dependencies.
143 | #Pipfile.lock
144 |
145 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
146 | __pypackages__/
147 |
148 | # Celery stuff
149 | celerybeat-schedule
150 | celerybeat.pid
151 |
152 | # SageMath parsed files
153 | *.sage.py
154 |
155 | # Environments
156 | .env
157 | .venv
158 | env/
159 | venv/
160 | ENV/
161 | env.bak/
162 | venv.bak/
163 |
164 | # Spyder project settings
165 | .spyderproject
166 | .spyproject
167 |
168 | # Rope project settings
169 | .ropeproject
170 |
171 | # mkdocs documentation
172 | /site
173 |
174 | # mypy
175 | .mypy_cache/
176 | .dmypy.json
177 | dmypy.json
178 |
179 | # Pyre type checker
180 | .pyre/
181 |
182 | # pytype static type analyzer
183 | .pytype/
184 |
185 | # Cython debug symbols
186 | cython_debug/
187 |
188 | ### VisualStudioCode ###
189 | .vscode/*
190 | !.vscode/settings.json
191 | !.vscode/tasks.json
192 | !.vscode/launch.json
193 | !.vscode/extensions.json
194 | *.code-workspace
195 |
196 | # Local History for Visual Studio Code
197 | .history/
198 |
199 | ### VisualStudioCode Patch ###
200 | # Ignore all local history of files
201 | .history
202 | .ionide
203 |
204 | ### Windows ###
205 | # Windows thumbnail cache files
206 | Thumbs.db
207 | Thumbs.db:encryptable
208 | ehthumbs.db
209 | ehthumbs_vista.db
210 |
211 | # Dump file
212 | *.stackdump
213 |
214 | # Folder config file
215 | [Dd]esktop.ini
216 |
217 | # Recycle Bin used on file shares
218 | $RECYCLE.BIN/
219 |
220 | # Windows Installer files
221 | *.cab
222 | *.msi
223 | *.msix
224 | *.msm
225 | *.msp
226 |
227 | # Windows shortcuts
228 | *.lnk
229 |
230 | # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
231 |
232 | # setuptools-scm _version file
233 | src/animatediff/_version.py
234 |
235 | # local misc and temp
236 | /misc/
237 | /temp/
238 |
239 | # envrc
240 | .env*
241 | !.envrc.example
242 |
--------------------------------------------------------------------------------
/src/animatediff/softmax_splatting/README.md:
--------------------------------------------------------------------------------
1 | # softmax-splatting
2 | This is a reference implementation of the softmax splatting operator, which was proposed in Softmax Splatting for Video Frame Interpolation [1], using PyTorch. Softmax splatting is a well-motivated approach for differentiable forward warping. It uses a translationally invariant importance metric to disambiguate cases where multiple source pixels map to the same target pixel. Should you be making use of our work, please cite our paper [1].
3 |
4 |
5 |
6 | For our previous work on SepConv, see: https://github.com/sniklaus/revisiting-sepconv
7 |
8 | ## setup
9 | The softmax splatting is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository.
10 |
11 | If you plan to process videos, then please also make sure to have `pip install moviepy` installed.
12 |
13 | ## usage
14 | To run it on your own pair of frames, use the following command.
15 |
16 | ```
17 | python run.py --model lf --one ./images/one.png --two ./images/two.png --out ./out.png
18 | ```
19 |
20 | To run it on a video, use the following command.
21 |
22 | ```
23 | python run.py --model lf --video ./videos/car-turn.mp4 --out ./out.mp4
24 | ```
25 |
26 | For a quick benchmark using examples from the Middlebury benchmark for optical flow, run `python benchmark_middlebury.py`. You can use it to easily verify that the provided implementation runs as expected.
27 |
28 | ## warping
29 | We provide a small script to replicate the third figure of our paper [1]. You can simply run the following to obtain the comparison between summation splatting, average splatting, linear splatting, and softmax splatting.
30 |
31 | The example script uses OpenCV to load and display images, as well as to read the provided optical flow file. An easy way to install OpenCV for Python is the `opencv-contrib-python` package: `pip install opencv-contrib-python`.
32 |
33 | ```
34 | import cv2
35 | import numpy
36 | import torch
37 |
38 | import run
39 |
40 | import softsplat # the custom softmax splatting layer
41 |
42 | ##########################################################
43 |
44 | torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance
45 |
46 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
47 |
48 | ##########################################################
49 |
50 | tenOne = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
51 | tenTwo = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
52 | tenFlow = torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda()
53 |
54 | tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=run.backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True)
55 |
56 | for intTime, fltTime in enumerate(numpy.linspace(0.0, 1.0, 11).tolist()):
57 | tenSummation = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='sum')
58 | tenAverage = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='avg')
59 | 	tenLinear = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(0.3 - tenMetric).clip(0.001, 1.0), strMode='linear') # finding a good linear metric is difficult, and it is not invariant to translations
60 | tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft') # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter
61 |
62 | cv2.imshow(winname='summation', mat=tenSummation[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
63 | cv2.imshow(winname='average', mat=tenAverage[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
64 | cv2.imshow(winname='linear', mat=tenLinear[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
65 | cv2.imshow(winname='softmax', mat=tenSoftmax[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
66 | cv2.waitKey(delay=0)
67 | # end
68 | ```
69 |
70 | ## xiph
71 | In our paper, we propose to use 4K video clips from Xiph to evaluate video frame interpolation on high-resolution footage. Please see the supplementary `benchmark_xiph.py` for how to reproduce the reported metrics.
72 |
73 | ## video
74 |
75 |
76 | ## license
77 | The provided implementation is strictly for academic purposes only. Should you be interested in using our technology for any commercial use, please feel free to contact us.
78 |
79 | ## references
80 | ```
81 | [1] @inproceedings{Niklaus_CVPR_2020,
82 | author = {Simon Niklaus and Feng Liu},
83 | title = {Softmax Splatting for Video Frame Interpolation},
84 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
85 | year = {2020}
86 | }
87 | ```
88 |
89 | ## acknowledgment
90 | The video above uses materials under a Creative Commons license, as detailed at the end.
--------------------------------------------------------------------------------
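The README above describes softmax splatting as a forward warp weighted by a translationally invariant importance metric. As a rough sketch of the paper's formulation [1] (symbol names here are illustrative, not taken from the code): the output is a summation splat of the exponentially weighted source frame, normalized by the splatted weights.

```
% Sketch of softmax splatting as a normalized summation splat, following [1].
% I_0: source frame, F: forward flow scaled by time t, Z: importance metric,
% b(.): bilinear splatting kernel, p: target pixel, q: source pixel.
I_t[p] = \frac{\sum_q \exp(Z[q]) \, b\!\left(p - (q + t \cdot F[q])\right) I_0[q]}
              {\sum_q \exp(Z[q]) \, b\!\left(p - (q + t \cdot F[q])\right)}
```

This matches the `strMode='soft'` call in the warping example, where `-20.0 * tenMetric` plays the role of `Z`.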
/src/animatediff/utils/huggingface.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from os import PathLike
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
7 | from huggingface_hub import hf_hub_download, snapshot_download
8 | from tqdm.rich import tqdm
9 |
10 | from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir
11 | from animatediff.utils.util import path_from_cwd
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | data_dir = get_dir("data")
16 | checkpoint_dir = data_dir.joinpath("models/sd")
17 | pipeline_dir = data_dir.joinpath("models/huggingface")
18 |
19 | IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
20 | IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
21 | IGNORE_TF_FLAX = IGNORE_TF + IGNORE_FLAX
22 |
23 |
24 | class DownloadTqdm(tqdm):
25 | def __init__(self, *args, **kwargs):
26 | kwargs.update(
27 | {
28 | "ncols": 100,
29 | "dynamic_ncols": False,
30 | "disable": None,
31 | }
32 | )
33 | super().__init__(*args, **kwargs)
34 |
35 |
36 | def get_hf_file(
37 | repo_id: Path,
38 | filename: str,
39 | target_dir: Path,
40 | subfolder: Optional[PathLike] = None,
41 | revision: Optional[str] = None,
42 | force: bool = False,
43 | ) -> Path:
44 | target_path = target_dir.joinpath(filename)
45 | if target_path.exists() and force is not True:
46 | raise FileExistsError(
47 | f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
48 | )
49 |
50 | target_dir.mkdir(exist_ok=True, parents=True)
51 | save_path = hf_hub_download(
52 | repo_id=str(repo_id),
53 | filename=filename,
54 | revision=revision or "main",
55 | subfolder=subfolder,
56 | local_dir=target_dir,
57 | local_dir_use_symlinks=False,
58 | cache_dir=HF_HUB_CACHE,
59 | resume_download=True,
60 | )
61 | return Path(save_path)
62 |
63 |
64 | def get_hf_repo(
65 | repo_id: Path,
66 | target_dir: Path,
67 | subfolder: Optional[PathLike] = None,
68 | revision: Optional[str] = None,
69 | force: bool = False,
70 | ) -> Path:
71 | if target_dir.exists() and force is not True:
72 | raise FileExistsError(
73 | f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
74 | )
75 |
76 | target_dir.mkdir(exist_ok=True, parents=True)
77 | save_path = snapshot_download(
78 | repo_id=str(repo_id),
79 | revision=revision or "main",
80 | subfolder=subfolder,
81 | library_name=HF_LIB_NAME,
82 | library_version=HF_LIB_VER,
83 | local_dir=target_dir,
84 | local_dir_use_symlinks=False,
85 | ignore_patterns=IGNORE_TF_FLAX,
86 | cache_dir=HF_HUB_CACHE,
87 | tqdm_class=DownloadTqdm,
88 | max_workers=2,
89 | resume_download=True,
90 | )
91 | return Path(save_path)
92 |
93 |
94 | def get_hf_pipeline(
95 | repo_id: Path,
96 | target_dir: Path,
97 | save: bool = True,
98 | force_download: bool = False,
99 | ) -> StableDiffusionPipeline:
100 | pipeline_exists = target_dir.joinpath("model_index.json").exists()
101 | if pipeline_exists and force_download is not True:
102 | pipeline = StableDiffusionPipeline.from_pretrained(
103 | pretrained_model_name_or_path=target_dir,
104 | local_files_only=True,
105 | )
106 | else:
107 | target_dir.mkdir(exist_ok=True, parents=True)
108 | pipeline = StableDiffusionPipeline.from_pretrained(
109 | pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
110 | cache_dir=HF_HUB_CACHE,
111 | resume_download=True,
112 | )
113 | if save and force_download:
114 | logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
115 | pipeline.save_pretrained(target_dir, safe_serialization=True)
116 | elif save and not pipeline_exists:
117 | logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
118 | pipeline.save_pretrained(target_dir, safe_serialization=True)
119 | return pipeline
120 |
121 | def get_hf_pipeline_sdxl(
122 | repo_id: Path,
123 | target_dir: Path,
124 | save: bool = True,
125 | force_download: bool = False,
126 | ) -> StableDiffusionXLPipeline:
127 | import torch
128 | pipeline_exists = target_dir.joinpath("model_index.json").exists()
129 | if pipeline_exists and force_download is not True:
130 | pipeline = StableDiffusionXLPipeline.from_pretrained(
131 | pretrained_model_name_or_path=target_dir,
132 | local_files_only=True,
133 | torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
134 | )
135 | else:
136 | target_dir.mkdir(exist_ok=True, parents=True)
137 | pipeline = StableDiffusionXLPipeline.from_pretrained(
138 | pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"),
139 | cache_dir=HF_HUB_CACHE,
140 | resume_download=True,
141 | torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
142 | )
143 | if save and force_download:
144 | logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
145 | pipeline.save_pretrained(target_dir, safe_serialization=True)
146 | elif save and not pipeline_exists:
147 | logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
148 | pipeline.save_pretrained(target_dir, safe_serialization=True)
149 | return pipeline
150 |
--------------------------------------------------------------------------------
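A minimal usage sketch for the helpers above, with an illustrative repo id and target directory (neither is taken from the repo's configs). `get_hf_pipeline` reuses the local copy when `model_index.json` already exists and otherwise downloads and optionally saves it:

```
# Hypothetical usage of get_hf_pipeline; repo id and paths are illustrative.
from pathlib import Path

from animatediff.utils.huggingface import get_hf_pipeline

pipeline = get_hf_pipeline(
    repo_id=Path("runwayml/stable-diffusion-v1-5"),    # HF repo id (passed through str())
    target_dir=Path("data/models/huggingface/sd15"),   # local pipeline directory
    save=True,                                         # persist to target_dir after download
    force_download=False,                              # reuse the local copy if it exists
)
```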
/src/animatediff/ip_adapter/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | from einops import rearrange
8 | from einops.layers.torch import Rearrange
9 |
10 |
11 | # FFN
12 | def FeedForward(dim, mult=4):
13 | inner_dim = int(dim * mult)
14 | return nn.Sequential(
15 | nn.LayerNorm(dim),
16 | nn.Linear(dim, inner_dim, bias=False),
17 | nn.GELU(),
18 | nn.Linear(inner_dim, dim, bias=False),
19 | )
20 |
21 |
22 | def reshape_tensor(x, heads):
23 | bs, length, width = x.shape
24 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
25 | x = x.view(bs, length, heads, -1)
26 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
27 | x = x.transpose(1, 2)
28 |     # heads stay in their own dimension: (bs, n_heads, length, dim_per_head)
29 |     x = x.reshape(bs, heads, length, -1)
30 | return x
31 |
32 |
33 | class PerceiverAttention(nn.Module):
34 | def __init__(self, *, dim, dim_head=64, heads=8):
35 | super().__init__()
36 | self.scale = dim_head**-0.5
37 | self.dim_head = dim_head
38 | self.heads = heads
39 | inner_dim = dim_head * heads
40 |
41 | self.norm1 = nn.LayerNorm(dim)
42 | self.norm2 = nn.LayerNorm(dim)
43 |
44 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
45 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
46 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
47 |
48 |
49 | def forward(self, x, latents):
50 | """
51 | Args:
52 | x (torch.Tensor): image features
53 | shape (b, n1, D)
54 |             latents (torch.Tensor): latent features
55 | shape (b, n2, D)
56 | """
57 | x = self.norm1(x)
58 | latents = self.norm2(latents)
59 |
60 | b, l, _ = latents.shape
61 |
62 | q = self.to_q(latents)
63 | kv_input = torch.cat((x, latents), dim=-2)
64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65 |
66 | q = reshape_tensor(q, self.heads)
67 | k = reshape_tensor(k, self.heads)
68 | v = reshape_tensor(v, self.heads)
69 |
70 | # attention
71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74 | out = weight @ v
75 |
76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77 |
78 | return self.to_out(out)
79 |
80 |
81 | class Resampler(nn.Module):
82 | def __init__(
83 | self,
84 | dim=1024,
85 | depth=8,
86 | dim_head=64,
87 | heads=16,
88 | num_queries=8,
89 | embedding_dim=768,
90 | output_dim=1024,
91 | ff_mult=4,
92 | max_seq_len: int = 257, # CLIP tokens + CLS token
93 | apply_pos_emb: bool = False,
94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95 | ):
96 | super().__init__()
97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98 |
99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100 |
101 | self.proj_in = nn.Linear(embedding_dim, dim)
102 |
103 | self.proj_out = nn.Linear(dim, output_dim)
104 | self.norm_out = nn.LayerNorm(output_dim)
105 |
106 | self.to_latents_from_mean_pooled_seq = (
107 | nn.Sequential(
108 | nn.LayerNorm(dim),
109 | nn.Linear(dim, dim * num_latents_mean_pooled),
110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111 | )
112 | if num_latents_mean_pooled > 0
113 | else None
114 | )
115 |
116 | self.layers = nn.ModuleList([])
117 | for _ in range(depth):
118 | self.layers.append(
119 | nn.ModuleList(
120 | [
121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122 | FeedForward(dim=dim, mult=ff_mult),
123 | ]
124 | )
125 | )
126 |
127 | def forward(self, x):
128 | if self.pos_emb is not None:
129 | n, device = x.shape[1], x.device
130 | pos_emb = self.pos_emb(torch.arange(n, device=device))
131 | x = x + pos_emb
132 |
133 | latents = self.latents.repeat(x.size(0), 1, 1)
134 |
135 | x = self.proj_in(x)
136 |
137 | if self.to_latents_from_mean_pooled_seq:
138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140 | latents = torch.cat((meanpooled_latents, latents), dim=-2)
141 |
142 | for attn, ff in self.layers:
143 | latents = attn(x, latents) + latents
144 | latents = ff(latents) + latents
145 |
146 | latents = self.proj_out(latents)
147 | return self.norm_out(latents)
148 |
149 |
150 | def masked_mean(t, *, dim, mask=None):
151 | if mask is None:
152 | return t.mean(dim=dim)
153 |
154 | denom = mask.sum(dim=dim, keepdim=True)
155 | mask = rearrange(mask, "b n -> b n 1")
156 | masked_t = t.masked_fill(~mask, 0.0)
157 |
158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
159 |
--------------------------------------------------------------------------------
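A small shape-check sketch for the `Resampler` above; the dimensions are illustrative and simply show a batch of CLIP-style image features being resampled into a fixed number of query tokens:

```
# Illustrative shape check for Resampler (dimensions chosen arbitrarily).
import torch

from animatediff.ip_adapter.resampler import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=768, output_dim=768)
image_embeds = torch.randn(2, 257, 768)   # (batch, CLIP tokens + CLS, embedding_dim)
tokens = resampler(image_embeds)          # cross-attend learned latents onto the features
print(tokens.shape)                       # torch.Size([2, 16, 768])
```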
/src/animatediff/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ Conversion script for the LoRA's safetensors checkpoints. """
17 |
18 | import argparse
19 |
20 | import torch
21 |
22 |
23 | def convert_lora(
24 | pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6
25 | ):
26 | # load base model
27 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
28 |
29 | # load LoRA weight from .safetensors
30 | # state_dict = load_file(checkpoint_path)
31 |
32 | visited = []
33 |
34 | # directly update weight in diffusers model
35 | for key in state_dict:
36 | # it is suggested to print out the key, it usually will be something like below
37 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
38 |
39 | # as we have set the alpha beforehand, so just skip
40 | if ".alpha" in key or key in visited:
41 | continue
42 |
43 | if "text" in key:
44 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
45 | curr_layer = pipeline.text_encoder
46 | else:
47 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
48 | curr_layer = pipeline.unet
49 |
50 | # find the target layer
51 | temp_name = layer_infos.pop(0)
52 | while len(layer_infos) > -1:
53 | try:
54 | curr_layer = curr_layer.__getattr__(temp_name)
55 | if len(layer_infos) > 0:
56 | temp_name = layer_infos.pop(0)
57 | elif len(layer_infos) == 0:
58 | break
59 | except Exception:
60 | if len(temp_name) > 0:
61 | temp_name += "_" + layer_infos.pop(0)
62 | else:
63 | temp_name = layer_infos.pop(0)
64 |
65 | pair_keys = []
66 | if "lora_down" in key:
67 | pair_keys.append(key.replace("lora_down", "lora_up"))
68 | pair_keys.append(key)
69 | else:
70 | pair_keys.append(key)
71 | pair_keys.append(key.replace("lora_up", "lora_down"))
72 |
73 | # update weight
74 | if len(state_dict[pair_keys[0]].shape) == 4:
75 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
76 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
77 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(
78 | curr_layer.weight.data.device
79 | )
80 | else:
81 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
82 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
83 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(
84 | curr_layer.weight.data.device
85 | )
86 |
87 | # update visited list
88 | for item in pair_keys:
89 | visited.append(item)
90 |
91 | return pipeline
92 |
93 |
94 | if __name__ == "__main__":
95 | parser = argparse.ArgumentParser()
96 |
97 | parser.add_argument(
98 | "--base_model_path",
99 | default=None,
100 | type=str,
101 | required=True,
102 | help="Path to the base model in diffusers format.",
103 | )
104 | parser.add_argument(
105 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
106 | )
107 | parser.add_argument(
108 | "--dump_path", default=None, type=str, required=True, help="Path to the output model."
109 | )
110 | parser.add_argument(
111 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
112 | )
113 | parser.add_argument(
114 | "--lora_prefix_text_encoder",
115 | default="lora_te",
116 | type=str,
117 | help="The prefix of text encoder weight in safetensors",
118 | )
119 | parser.add_argument(
120 | "--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW"
121 | )
122 | parser.add_argument(
123 | "--to_safetensors",
124 | action="store_true",
125 | help="Whether to store pipeline in safetensors format or not.",
126 | )
127 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
128 |
129 | args = parser.parse_args()
130 |
131 | base_model_path = args.base_model_path
132 | checkpoint_path = args.checkpoint_path
133 | dump_path = args.dump_path
134 | lora_prefix_unet = args.lora_prefix_unet
135 | lora_prefix_text_encoder = args.lora_prefix_text_encoder
136 | alpha = args.alpha
137 |
138 |     # convert_lora expects a loaded pipeline and LoRA state dict, so load them first
139 |     from diffusers import StableDiffusionPipeline
140 |     from safetensors.torch import load_file
141 |
142 |     pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
143 |     state_dict = load_file(checkpoint_path)
144 |     pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
145 |
146 |     pipe = pipe.to(args.device)
147 |     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
142 |
--------------------------------------------------------------------------------
/src/animatediff/settings.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from os import PathLike
4 | from pathlib import Path
5 | from typing import Any, Dict, Optional, Tuple, Union
6 |
7 | from pydantic.v1 import BaseConfig, BaseSettings, Field
8 | from pydantic.v1.env_settings import (EnvSettingsSource, InitSettingsSource,
9 |                                        SecretsSettingsSource,
10 |                                        SettingsSourceCallable)
11 |
12 | from animatediff import get_dir
13 | from animatediff.schedulers import DiffusionScheduler
14 |
15 | logging.basicConfig(level=logging.INFO)
16 | logger = logging.getLogger(__name__)
17 |
18 | CKPT_EXTENSIONS = [".pt", ".ckpt", ".pth", ".safetensors"]
19 |
20 |
21 | class JsonSettingsSource:
22 | __slots__ = ["json_config_path"]
23 |
24 | def __init__(
25 | self,
26 | json_config_path: Optional[Union[PathLike, list[PathLike]]] = list(),
27 | ) -> None:
28 | if isinstance(json_config_path, list):
29 | self.json_config_path = [Path(path) for path in json_config_path]
30 | else:
31 | self.json_config_path = [Path(json_config_path)] if json_config_path is not None else []
32 |
33 | def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
34 | classname = settings.__class__.__name__
35 | encoding = settings.__config__.env_file_encoding
36 | if len(self.json_config_path) == 0:
37 | pass # no json config provided
38 |
39 | merged_config = dict() # create an empty dict to merge configs into
40 | for idx, path in enumerate(self.json_config_path):
41 | if path.exists() and path.is_file(): # check if the path exists and is a file
42 | logger.debug(f"{classname}: loading config #{idx+1} from {path}")
43 | merged_config.update(json.loads(path.read_text(encoding=encoding)))
44 | logger.debug(f"{classname}: config state #{idx+1}: {merged_config}")
45 | else:
46 | raise FileNotFoundError(f"{classname}: config #{idx+1} at {path} not found or not a file")
47 |
48 | logger.debug(f"{classname}: loaded config: {merged_config}")
49 | return merged_config # return the merged config
50 |
51 | def __repr__(self) -> str:
52 | return f"JsonSettingsSource(json_config_path={repr(self.json_config_path)})"
53 |
54 |
55 | class JsonConfig(BaseConfig):
56 | json_config_path: Optional[Union[Path, list[Path]]] = None
57 | env_file_encoding: str = "utf-8"
58 |
59 | @classmethod
60 | def customise_sources(
61 | cls,
62 | init_settings: InitSettingsSource,
63 | env_settings: EnvSettingsSource,
64 | file_secret_settings: SecretsSettingsSource,
65 | ) -> Tuple[SettingsSourceCallable, ...]:
66 | # pull json_config_path from init_settings if passed, otherwise use the class var
67 | json_config_path = init_settings.init_kwargs.pop("json_config_path", cls.json_config_path)
68 |
69 | logger.debug(f"Using JsonSettingsSource for {cls.__name__}")
70 | json_settings = JsonSettingsSource(json_config_path=json_config_path)
71 |
72 | # return the new settings sources
73 | return (
74 | init_settings,
75 | json_settings,
76 | )
77 |
78 |
79 | class InferenceConfig(BaseSettings):
80 | unet_additional_kwargs: dict[str, Any]
81 | noise_scheduler_kwargs: dict[str, Any]
82 |
83 | class Config(JsonConfig):
84 | json_config_path: Path
85 |
86 |
87 | def get_infer_config(
88 | is_v2:bool,
89 | is_sdxl:bool,
90 | ) -> InferenceConfig:
91 | config_path: Path = get_dir("config").joinpath("inference/default.json" if not is_v2 else "inference/motion_v2.json")
92 |
93 | if is_sdxl:
94 | config_path = get_dir("config").joinpath("inference/motion_sdxl.json")
95 |
96 | settings = InferenceConfig(json_config_path=config_path)
97 | return settings
98 |
99 |
100 | class ModelConfig(BaseSettings):
101 | name: str = Field(...) # Config name, not actually used for much of anything
102 | path: Path = Field(...) # Path to the model
103 |     vae_path: str = ""  # Path to the VAE (optional)
104 | motion_module: Path = Field(...) # Path to the motion module
105 | context_schedule: str = "uniform"
106 | lcm_map: Dict[str,Any]= Field({})
107 | gradual_latent_hires_fix_map: Dict[str,Any]= Field({})
108 | compile: bool = Field(False) # whether to compile the model with TorchDynamo
109 | tensor_interpolation_slerp: bool = Field(True)
110 | seed: list[int] = Field([]) # Seed(s) for the random number generators
111 | scheduler: DiffusionScheduler = Field(DiffusionScheduler.k_dpmpp_2m) # Scheduler to use
112 | steps: int = 25 # Number of inference steps to run
113 | guidance_scale: float = 7.5 # CFG scale to use
114 | unet_batch_size: int = 1
115 | clip_skip: int = 1 # skip the last N-1 layers of the CLIP text encoder
116 | prompt_fixed_ratio: float = 0.5
117 | head_prompt: str = ""
118 | prompt_map: Dict[str,str]= Field({})
119 | tail_prompt: str = ""
120 | n_prompt: list[str] = Field([]) # Anti-prompt(s) to use
121 | is_single_prompt_mode : bool = Field(False)
122 | lora_map: Dict[str,Any]= Field({})
123 | motion_lora_map: Dict[str,float]= Field({})
124 | ip_adapter_map: Dict[str,Any]= Field({})
125 | img2img_map: Dict[str,Any]= Field({})
126 | region_map: Dict[str,Any]= Field({})
127 | controlnet_map: Dict[str,Any]= Field({})
128 | upscale_config: Dict[str,Any]= Field({})
129 | stylize_config: Dict[str,Any]= Field({})
130 | output: Dict[str,Any]= Field({})
131 | result: Dict[str,Any]= Field({})
132 |
133 | class Config(JsonConfig):
134 | json_config_path: Path
135 |
136 | @property
137 | def save_name(self):
138 | return f"{self.name.lower()}-{self.path.stem.lower()}"
139 |
140 |
141 | def get_model_config(config_path: Path) -> ModelConfig:
142 | settings = ModelConfig(json_config_path=config_path)
143 | return settings
144 |
--------------------------------------------------------------------------------
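A brief sketch of how these settings classes are typically constructed from JSON; the config path below is a placeholder. `JsonSettingsSource` merges the JSON file(s) into the pydantic model fields:

```
# Illustrative use of the JSON-backed settings (the config path is a placeholder).
from pathlib import Path

from animatediff.settings import get_infer_config, get_model_config

model_config = get_model_config(Path("config/prompts/my_prompt.json"))
infer_config = get_infer_config(is_v2=True, is_sdxl=False)

print(model_config.name, model_config.steps, model_config.guidance_scale)
print(model_config.save_name)  # "<name>-<checkpoint file stem>"
```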
/src/animatediff/utils/tagger.py:
--------------------------------------------------------------------------------
1 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
2 |
3 | import glob
4 | import logging
5 | import os
6 |
7 | import cv2
8 | import numpy as np
9 | import onnxruntime
10 | import pandas as pd
11 | from PIL import Image
12 | from tqdm.rich import tqdm
13 |
14 | from animatediff.utils.util import prepare_wd14tagger
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def make_square(img, target_size):
20 | old_size = img.shape[:2]
21 | desired_size = max(old_size)
22 | desired_size = max(desired_size, target_size)
23 |
24 | delta_w = desired_size - old_size[1]
25 | delta_h = desired_size - old_size[0]
26 | top, bottom = delta_h // 2, delta_h - (delta_h // 2)
27 | left, right = delta_w // 2, delta_w - (delta_w // 2)
28 |
29 | color = [255, 255, 255]
30 | new_im = cv2.copyMakeBorder(
31 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
32 | )
33 | return new_im
34 |
35 | def smart_resize(img, size):
36 | # Assumes the image has already gone through make_square
37 | if img.shape[0] > size:
38 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
39 | elif img.shape[0] < size:
40 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
41 | return img
42 |
43 |
44 | class Tagger:
45 | def __init__(self, general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format,is_cpu):
46 | prepare_wd14tagger()
47 | # self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CUDAExecutionProvider','CPUExecutionProvider'])
48 | if is_cpu:
49 | self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CPUExecutionProvider'])
50 | else:
51 | self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CUDAExecutionProvider'])
52 | df = pd.read_csv("data/models/WD14tagger/selected_tags.csv")
53 | self.tag_names = df["name"].tolist()
54 | self.rating_indexes = list(np.where(df["category"] == 9)[0])
55 | self.general_indexes = list(np.where(df["category"] == 0)[0])
56 | self.character_indexes = list(np.where(df["category"] == 4)[0])
57 |
58 | self.general_threshold = general_threshold
59 | self.character_threshold = character_threshold
60 | self.ignore_tokens = ignore_tokens
61 | self.with_confidence = with_confidence
62 | self.is_danbooru_format = is_danbooru_format
63 |
64 | def __call__(
65 | self,
66 | image: Image,
67 | ):
68 |
69 | _, height, width, _ = self.model.get_inputs()[0].shape
70 |
71 | # Alpha to white
72 | image = image.convert("RGBA")
73 | new_image = Image.new("RGBA", image.size, "WHITE")
74 | new_image.paste(image, mask=image)
75 | image = new_image.convert("RGB")
76 | image = np.asarray(image)
77 |
78 | # PIL RGB to OpenCV BGR
79 | image = image[:, :, ::-1]
80 |
81 | image = make_square(image, height)
82 | image = smart_resize(image, height)
83 | image = image.astype(np.float32)
84 | image = np.expand_dims(image, 0)
85 |
86 | input_name = self.model.get_inputs()[0].name
87 | label_name = self.model.get_outputs()[0].name
88 | probs = self.model.run([label_name], {input_name: image})[0]
89 |
90 | labels = list(zip(self.tag_names, probs[0].astype(float)))
91 |
92 | # First 4 labels are actually ratings: pick one with argmax
93 | ratings_names = [labels[i] for i in self.rating_indexes]
94 | rating = dict(ratings_names)
95 |
96 | # Then we have general tags: pick any where prediction confidence > threshold
97 | general_names = [labels[i] for i in self.general_indexes]
98 | general_res = [x for x in general_names if x[1] > self.general_threshold]
99 | general_res = dict(general_res)
100 |
101 | # Everything else is characters: pick any where prediction confidence > threshold
102 | character_names = [labels[i] for i in self.character_indexes]
103 | character_res = [x for x in character_names if x[1] > self.character_threshold]
104 | character_res = dict(character_res)
105 |
106 | #logger.info(f"{rating=}")
107 | #logger.info(f"{general_res=}")
108 | #logger.info(f"{character_res=}")
109 |
110 | general_res = {k:general_res[k] for k in (general_res.keys() - set(self.ignore_tokens)) }
111 | character_res = {k:character_res[k] for k in (character_res.keys() - set(self.ignore_tokens)) }
112 |
113 | prompt = ""
114 |
115 | if self.with_confidence:
116 | prompt = [ f"({i}:{character_res[i]:.2f})" for i in (character_res.keys()) ]
117 | prompt += [ f"({i}:{general_res[i]:.2f})" for i in (general_res.keys()) ]
118 | else:
119 | prompt = [ i for i in (character_res.keys()) ]
120 | prompt += [ i for i in (general_res.keys()) ]
121 |
122 | prompt = ",".join(prompt)
123 |
124 | if not self.is_danbooru_format:
125 | prompt = prompt.replace("_", " ")
126 |
127 | #logger.info(f"{prompt=}")
128 | return prompt
129 |
130 |
131 | def get_labels(frame_dir, interval, general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format, is_cpu =False):
132 |
133 | import torch
134 |
135 | result = {}
136 | if os.path.isdir(frame_dir):
137 | png_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
138 |
139 | png_map ={}
140 | for png_path in png_list:
141 | basename_without_ext = os.path.splitext(os.path.basename(png_path))[0]
142 | png_map[int(basename_without_ext)] = png_path
143 |
144 | with torch.no_grad():
145 | tagger = Tagger(general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format, is_cpu)
146 |
147 | for i in tqdm(range(0, len(png_list), interval ), desc=f"WD14tagger"):
148 | path = png_map[i]
149 |
150 | #logger.info(f"{path=}")
151 |
152 | result[str(i)] = tagger(
153 | image= Image.open(path)
154 | )
155 |
156 | tagger = None
157 |
158 | torch.cuda.empty_cache()
159 |
160 | return result
161 |
162 |
--------------------------------------------------------------------------------
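A hedged usage sketch for `get_labels`; the frame directory, interval, and thresholds are illustrative, and the WD14 model files are expected under `data/models/WD14tagger/` as in the code above:

```
# Illustrative call to the WD14 tagger helper (paths and thresholds are placeholders).
from animatediff.utils.tagger import get_labels

labels = get_labels(
    frame_dir="output/frames",   # directory containing 0000.png, 0001.png, ...
    interval=4,                  # tag every 4th frame
    general_threshold=0.35,
    character_threshold=0.85,
    ignore_tokens=[],
    with_confidence=True,
    is_danbooru_format=False,
)
print(labels.get("0", ""))       # prompt string built from the tags of frame 0
```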
/src/animatediff/utils/composite.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | import shutil
5 | from pathlib import Path
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from PIL import Image
12 | from tqdm.rich import tqdm
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | #https://github.com/jinwonkim93/laplacian-pyramid-blend
18 | #https://blog.shikoan.com/pytorch-laplacian-pyramid/
19 | class LaplacianPyramidBlender:
20 |
21 | device = None
22 |
23 | def get_gaussian_kernel(self):
24 | kernel = np.array([
25 | [1, 4, 6, 4, 1],
26 | [4, 16, 24, 16, 4],
27 | [6, 24, 36, 24, 6],
28 | [4, 16, 24, 16, 4],
29 | [1, 4, 6, 4, 1]], np.float32) / 256.0
30 | gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5),device=self.device)
31 | return gaussian_k
32 |
33 | def pyramid_down(self, image):
34 | with torch.no_grad():
35 | gaussian_k = self.get_gaussian_kernel()
36 | multiband = [F.conv2d(image[:, i:i + 1,:,:], gaussian_k, padding=2, stride=2) for i in range(3)]
37 | down_image = torch.cat(multiband, dim=1)
38 | return down_image
39 |
40 | def pyramid_up(self, image, size = None):
41 | with torch.no_grad():
42 | gaussian_k = self.get_gaussian_kernel()
43 | if size is None:
44 | upsample = F.interpolate(image, scale_factor=2)
45 | else:
46 | upsample = F.interpolate(image, size=size)
47 | multiband = [F.conv2d(upsample[:, i:i + 1,:,:], gaussian_k, padding=2) for i in range(3)]
48 | up_image = torch.cat(multiband, dim=1)
49 | return up_image
50 |
51 | def gaussian_pyramid(self, original, n_pyramids):
52 | x = original
53 | # pyramid down
54 | pyramids = [original]
55 | for i in range(n_pyramids):
56 | x = self.pyramid_down(x)
57 | pyramids.append(x)
58 | return pyramids
59 |
60 | def laplacian_pyramid(self, original, n_pyramids):
61 | pyramids = self.gaussian_pyramid(original, n_pyramids)
62 |
63 | # pyramid up - diff
64 | laplacian = []
65 | for i in range(len(pyramids) - 1):
66 | diff = pyramids[i] - self.pyramid_up(pyramids[i + 1], pyramids[i].shape[2:])
67 | laplacian.append(diff)
68 |
69 | laplacian.append(pyramids[-1])
70 | return laplacian
71 |
72 | def laplacian_pyramid_blending_with_mask(self, src, target, mask, num_levels = 9):
73 | # assume mask is float32 [0,1]
74 |
75 | # generate Gaussian pyramid for src,target and mask
76 |
77 | Gsrc = torch.as_tensor(np.expand_dims(src, axis=0), device=self.device)
78 | Gtarget = torch.as_tensor(np.expand_dims(target, axis=0), device=self.device)
79 | Gmask = torch.as_tensor(np.expand_dims(mask, axis=0), device=self.device)
80 |
81 | lpA = self.laplacian_pyramid(Gsrc,num_levels)[::-1]
82 | lpB = self.laplacian_pyramid(Gtarget,num_levels)[::-1]
83 | gpMr = self.gaussian_pyramid(Gmask,num_levels)[::-1]
84 |
85 | # Now blend images according to mask in each level
86 | LS = []
87 | for idx, (la,lb,Gmask) in enumerate(zip(lpA,lpB,gpMr)):
88 | lo = lb * (1.0 - Gmask)
89 | if idx <= 2:
90 | lo += lb * Gmask
91 | else:
92 | lo += la * Gmask
93 | LS.append(lo)
94 |
95 | # now reconstruct
96 | ls_ = LS.pop(0)
97 | for lap in LS:
98 | ls_ = self.pyramid_up(ls_, lap.shape[2:]) + lap
99 |
100 | result = ls_.squeeze(dim=0).to('cpu').detach().numpy().copy()
101 |
102 | return result
103 |
104 | def __call__(self,
105 | src_image: np.ndarray,
106 | target_image: np.ndarray,
107 | mask_image: np.ndarray,
108 | device
109 | ):
110 |
111 | self.device = device
112 |
113 | num_levels = int(np.log2(src_image.shape[0]))
114 | #normalize image to 0, 1
115 | mask_image = np.clip(mask_image, 0, 1).transpose([2, 0, 1])
116 |
117 | src_image = src_image.transpose([2, 0, 1]).astype(np.float32) / 255.0
118 | target_image = target_image.transpose([2, 0, 1]).astype(np.float32) / 255.0
119 | composite_image = self.laplacian_pyramid_blending_with_mask(src_image, target_image, mask_image, num_levels)
120 | composite_image = np.clip(composite_image*255, 0 , 255).astype(np.uint8)
121 | composite_image=composite_image.transpose([1, 2, 0])
122 | return composite_image
123 |
124 |
125 | def composite(bg_dir, fg_list, output_dir, masked_area_list, device="cuda"):
126 | bg_list = sorted(glob.glob( os.path.join(bg_dir ,"[0-9]*.png"), recursive=False))
127 |
128 | blender = LaplacianPyramidBlender()
129 |
130 | for bg, fg_array, mask in tqdm(zip(bg_list, fg_list, masked_area_list),total=len(bg_list), desc="compositing"):
131 | name = Path(bg).name
132 | save_path = output_dir / name
133 |
134 | if fg_array is None:
135 | logger.info(f"composite fg_array is None -> skip")
136 | shutil.copy(bg, save_path)
137 | continue
138 |
139 | if mask is None:
140 | logger.info(f"mask is None -> skip")
141 | shutil.copy(bg, save_path)
142 | continue
143 |
144 | bg = np.asarray(Image.open(bg)).copy()
145 | fg = fg_array
146 | mask = np.concatenate([mask, mask, mask], 2)
147 |
148 | h, w, _ = bg.shape
149 |
150 | fg = cv2.resize(fg, dsize=(w,h))
151 | mask = cv2.resize(mask, dsize=(w,h))
152 |
153 |
154 | mask = mask.astype(np.float32)
155 | # mask = mask * 255
156 | mask = cv2.GaussianBlur(mask, (15, 15), 0)
157 | mask = mask / 255
158 |
159 | fg = fg * mask + bg * (1-mask)
160 |
161 | img = blender(fg, bg, mask,device)
162 |
163 |
164 | img = Image.fromarray(img)
165 | img.save(save_path)
166 |
167 | def simple_composite(bg_dir, fg_list, output_dir, masked_area_list, device="cuda"):
168 | bg_list = sorted(glob.glob( os.path.join(bg_dir ,"[0-9]*.png"), recursive=False))
169 |
170 | for bg, fg_array, mask in tqdm(zip(bg_list, fg_list, masked_area_list),total=len(bg_list), desc="compositing"):
171 | name = Path(bg).name
172 | save_path = output_dir / name
173 |
174 | if fg_array is None:
175 | logger.info(f"composite fg_array is None -> skip")
176 | shutil.copy(bg, save_path)
177 | continue
178 |
179 | if mask is None:
180 | logger.info(f"mask is None -> skip")
181 | shutil.copy(bg, save_path)
182 | continue
183 |
184 | bg = np.asarray(Image.open(bg)).copy()
185 | fg = fg_array
186 | mask = np.concatenate([mask, mask, mask], 2)
187 |
188 | h, w, _ = bg.shape
189 |
190 | fg = cv2.resize(fg, dsize=(w,h))
191 | mask = cv2.resize(mask, dsize=(w,h))
192 |
193 |
194 | mask = mask.astype(np.float32)
195 | mask = cv2.GaussianBlur(mask, (15, 15), 0)
196 | mask = mask / 255
197 |
198 | img = fg * mask + bg * (1-mask)
199 | img = img.clip(0 , 255).astype(np.uint8)
200 |
201 | img = Image.fromarray(img)
202 | img.save(save_path)
--------------------------------------------------------------------------------
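A toy sketch of `simple_composite`, assuming a directory of numbered background frames plus matching foreground arrays and single-channel masks (all values and paths illustrative):

```
# Toy example for simple_composite; shapes and paths are placeholders.
from pathlib import Path

import numpy as np

from animatediff.utils.composite import simple_composite

bg_dir = Path("output/background")   # 0000.png, 0001.png, ... (H x W x 3)
out_dir = Path("output/composited")
out_dir.mkdir(parents=True, exist_ok=True)

fg_list = [np.zeros((768, 512, 3), dtype=np.uint8)]               # one RGB frame per background
masked_area_list = [np.full((768, 512, 1), 255, dtype=np.uint8)]  # 255 where the fg should show

simple_composite(bg_dir, fg_list, out_dir, masked_area_list)
```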
/example.md:
--------------------------------------------------------------------------------
1 | ### Example
2 |
3 | - region prompt (txt2img / no controlnet)
4 | - region 0 ... 1girl, upper body, etc.
5 | - region 1 ... ((car)), street, road, no human, etc.
6 | - background ... town, outdoors, etc.
7 | - ip adapter input for background / region 0 / region 1
8 |
9 |
10 | - `animatediff generate -c config/prompts/region_txt2img.json -W 512 -H 768 -L 32 -C 16`
11 | - region 0 mask / region 1 mask / txt2img
12 |
13 |