├── src └── animatediff │ ├── utils │ ├── __init__.py │ ├── wild_card.py │ ├── mask_rembg.py │ ├── mask_animseg.py │ ├── civitai2config.py │ ├── device.py │ ├── pipeline.py │ ├── huggingface.py │ ├── convert_lora_safetensor_to_diffusers.py │ ├── tagger.py │ └── composite.py │ ├── models │ ├── __init__.py │ ├── clip.py │ └── resnet.py │ ├── repo │ └── .gitignore │ ├── rife │ ├── __init__.py │ ├── ncnn.py │ ├── rife.py │ └── ffmpeg.py │ ├── __main__.py │ ├── ip_adapter │ ├── __init__.py │ └── resampler.py │ ├── softmax_splatting │ ├── correlation │ │ └── README.md │ └── README.md │ ├── pipelines │ ├── __init__.py │ ├── context.py │ ├── ti.py │ └── lora.py │ ├── dwpose │ ├── wholebody.py │ ├── __init__.py │ └── onnxdet.py │ ├── __init__.py │ ├── schedulers.py │ └── settings.py ├── setup.py ├── config ├── prompts │ ├── ignore_tokens.txt │ ├── to_8fps_Frames.bat │ ├── concat_2horizontal.bat │ ├── copy_png.bat │ ├── 01-ToonYou.json │ ├── 04-MajicMix.json │ ├── 08-GhibliBackground.json │ ├── 03-RcnzCartoon.json │ ├── 06-Tusun.json │ ├── 05-RealisticVision.json │ ├── 07-FilmVelvia.json │ ├── 02-Lyriel.json │ ├── prompt_travel_multi_controlnet.json │ └── img2img_sample.json ├── inference │ ├── motion_sdxl.json │ ├── default.json │ ├── motion_v2.json │ ├── sd15-unet3d.json │ └── sd15-unet.json └── GroundingDINO │ ├── GroundingDINO_SwinB_cfg.py │ └── GroundingDINO_SwinT_OGC.py ├── MANIFEST.in ├── scripts ├── download │ ├── 11-ToonYou.sh │ ├── 12-Lyriel.sh │ ├── 14-MajicMix.sh │ ├── 13-RcnzCartoon.sh │ ├── 15-RealisticVision.sh │ ├── 16-Tusun.sh │ ├── 17-FilmVelvia.sh │ ├── 18-GhibliBackground.sh │ ├── 03-BaseSD.py │ ├── 01-Motion-Modules.sh │ ├── sd-models.aria2 │ └── 02-All-SD-Models.sh └── test_persistent.py ├── data └── models │ ├── WD14tagger │ └── model.onnx │ └── README.md ├── .editorconfig ├── pyproject.toml ├── .pre-commit-config.yaml ├── setup.cfg ├── .vscode └── settings.json ├── test.py ├── app.py ├── requirements.txt ├── README.md ├── .gitignore └── example.md /src/animatediff/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/animatediff/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/animatediff/repo/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup() 4 | -------------------------------------------------------------------------------- /config/prompts/ignore_tokens.txt: -------------------------------------------------------------------------------- 1 | motion_blur 2 | blurry 3 | realistic 4 | depth_of_field 5 | -------------------------------------------------------------------------------- /config/prompts/to_8fps_Frames.bat: -------------------------------------------------------------------------------- 1 | ffmpeg -i %1 -start_number 0 -vf "scale=512:768,fps=8" %%04d.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # setuptools_scm will grab all tracked files, minus these exclusions 2 | prune 
.vscode 3 | -------------------------------------------------------------------------------- /src/animatediff/rife/__init__.py: -------------------------------------------------------------------------------- 1 | from .rife import app 2 | 3 | __all__ = [ 4 | "app", 5 | ] 6 | -------------------------------------------------------------------------------- /src/animatediff/__main__.py: -------------------------------------------------------------------------------- 1 | from animatediff.cli import cli 2 | 3 | if __name__ == "__main__": 4 | cli() 5 | -------------------------------------------------------------------------------- /config/prompts/concat_2horizontal.bat: -------------------------------------------------------------------------------- 1 | ffmpeg -i %1 -i %2 -filter_complex "[0:v][1:v]hstack=inputs=2[v]" -map "[v]" -crf 15 2horizontal.mp4 -------------------------------------------------------------------------------- /scripts/download/11-ToonYou.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/78775 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /scripts/download/12-Lyriel.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/72396 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /scripts/download/14-MajicMix.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/79068 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /data/models/WD14tagger/model.onnx: -------------------------------------------------------------------------------- 1 | ../../../../../.cache/huggingface/hub/models--SmilingWolf--wd-v1-4-moat-tagger-v2/blobs/b8cef913be4c9e8d93f9f903e74271416502ce0b4b04df0ff1e2f00df488aa03 -------------------------------------------------------------------------------- /scripts/download/13-RcnzCartoon.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/71009 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /scripts/download/15-RealisticVision.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/29460 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | -------------------------------------------------------------------------------- /data/models/README.md: -------------------------------------------------------------------------------- 1 | ## Folder that contains the weights 2 | 3 | Put the weights of the base model in a new 'huggingface' folder and the weights of the motion module in a new 'motion-module' folder. 4 | 5 | --------------------------------------------------------------------------------
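A minimal sketch of the layout that README describes, kept to the two folder names it mentions; the helper snippet itself and its data/models base path (taken from where the README lives) are illustrative assumptions rather than part of the repository:

# create_model_dirs.py -- illustrative sketch of the folder layout described above
from pathlib import Path

# The README above lives in data/models/, so that directory is used as the base here.
models_root = Path("data/models")

# Base model weights go under 'huggingface'; motion module weights go under 'motion-module'.
for sub in ("huggingface", "motion-module"):
    (models_root / sub).mkdir(parents=True, exist_ok=True)

print(sorted(p.name for p in models_root.iterdir()))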
/config/prompts/copy_png.bat: -------------------------------------------------------------------------------- 1 | 2 | setlocal enableDelayedExpansion 3 | FOR /l %%N in (1,1,%~n1) do ( 4 | set "n=00000%%N" 5 | set "TEST=!n:~-5!" 6 | echo !TEST! 7 | copy /y %1 !TEST!.png 8 | ) 9 | 10 | ren %1 00000.png 11 | 12 | -------------------------------------------------------------------------------- /scripts/download/16-Tusun.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/97261 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | wget https://civitai.com/api/download/models/50705 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 4 | -------------------------------------------------------------------------------- /scripts/download/17-FilmVelvia.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/90115 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | wget https://civitai.com/api/download/models/92475 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 4 | -------------------------------------------------------------------------------- /scripts/download/18-GhibliBackground.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget https://civitai.com/api/download/models/102828 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 3 | wget https://civitai.com/api/download/models/57618 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate 4 | -------------------------------------------------------------------------------- /src/animatediff/ip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .ip_adapter import (IPAdapter, IPAdapterFull, IPAdapterPlus, 2 | IPAdapterPlusXL, IPAdapterXL) 3 | 4 | __all__ = [ 5 | "IPAdapter", 6 | "IPAdapterPlus", 7 | "IPAdapterPlusXL", 8 | "IPAdapterXL", 9 | "IPAdapterFull", 10 | ] 11 | -------------------------------------------------------------------------------- /src/animatediff/softmax_splatting/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you make use of or modify this particular implementation, please acknowledge it appropriately.
-------------------------------------------------------------------------------- /src/animatediff/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .animation import AnimationPipeline, AnimationPipelineOutput 2 | from .context import get_context_scheduler, get_total_steps, ordered_halving, uniform 3 | from .ti import get_text_embeddings, load_text_embeddings 4 | 5 | __all__ = [ 6 | "AnimationPipeline", 7 | "AnimationPipelineOutput", 8 | "get_context_scheduler", 9 | "get_total_steps", 10 | "ordered_halving", 11 | "uniform", 12 | "get_text_embeddings", 13 | "load_text_embeddings", 14 | ] 15 | -------------------------------------------------------------------------------- /scripts/download/03-BaseSD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from diffusers.pipelines import StableDiffusionPipeline 3 | 4 | from animatediff import get_dir 5 | 6 | out_dir = get_dir("data/models/huggingface/stable-diffusion-v1-5") 7 | 8 | pipeline = StableDiffusionPipeline.from_pretrained( 9 | "runwayml/stable-diffusion-v1-5", 10 | use_safetensors=True, 11 | safety_checker=None, requires_safety_checker=False,  # passed directly so the safety checker is actually disabled 12 | ) 13 | pipeline.save_pretrained( 14 | save_directory=str(out_dir), 15 | safe_serialization=True, 16 | ) 17 | -------------------------------------------------------------------------------- /scripts/download/01-Motion-Modules.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Attempting download of Motion Module models from Google Drive." 4 | echo "If this fails, please download them manually from the links in the error messages/README." 5 | 6 | gdown 1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq -O models/motion-module/ || true 7 | gdown 1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu -O models/motion-module/ || true 8 | 9 | echo "Motion module download script complete." 10 | echo "If you see errors above, please download the models manually from the links in the error messages/README."
11 | exit 0 12 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [*.{json,jsonc}] 18 | indent_style = space 19 | indent_size = 2 20 | 21 | [.vscode/*.{json,jsonc}] 22 | indent_style = space 23 | indent_size = 4 24 | 25 | [*.{yml,yaml,toml}] 26 | indent_style = space 27 | indent_size = 2 28 | 29 | [*.md] 30 | trim_trailing_whitespace = false 31 | 32 | [Makefile] 33 | indent_style = tab 34 | indent_size = 8 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=46.4.0", "wheel", "setuptools_scm[toml]>=6.2"] 4 | 5 | [tool.setuptools_scm] 6 | write_to = "src/animatediff/_version.py" 7 | 8 | [tool.black] 9 | line-length = 110 10 | target-version = ['py310'] 11 | ignore = ['F841', 'F401', 'E501'] 12 | preview = true 13 | 14 | [tool.ruff] 15 | line-length = 110 16 | target-version = 'py310' 17 | ignore = ['F841', 'F401', 'E501'] 18 | 19 | [tool.ruff.isort] 20 | combine-as-imports = true 21 | force-wrap-aliases = true 22 | known-local-folder = ["src"] 23 | known-first-party = ["animatediff"] 24 | 25 | [tool.pyright] 26 | include = ['src/**'] 27 | exclude = ['/usr/lib/**'] 28 | -------------------------------------------------------------------------------- /config/inference/motion_sdxl.json: -------------------------------------------------------------------------------- 1 | { 2 | "unet_additional_kwargs": { 3 | "unet_use_temporal_attention": false, 4 | "use_motion_module": true, 5 | "motion_module_resolutions": [1, 2, 4, 8], 6 | "motion_module_mid_block": false, 7 | "motion_module_type": "Vanilla", 8 | "motion_module_kwargs": { 9 | "num_attention_heads": 8, 10 | "num_transformer_block": 1, 11 | "attention_block_types": ["Temporal_Self", "Temporal_Self"], 12 | "temporal_position_encoding": true, 13 | "temporal_position_encoding_max_len": 32, 14 | "temporal_attention_dim_div": 1 15 | } 16 | }, 17 | "noise_scheduler_kwargs": { 18 | "num_train_timesteps": 1000, 19 | "beta_start": 0.00085, 20 | "beta_end": 0.020, 21 | "beta_schedule": "scaled_linear" 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | ci: 3 | autofix_prs: true 4 | autoupdate_branch: "main" 5 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" 6 | autoupdate_schedule: weekly 7 | 8 | repos: 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: "v0.0.281" 11 | hooks: 12 | - id: ruff 13 | args: ["--fix", "--exit-non-zero-on-fix"] 14 | 15 | - repo: https://github.com/psf/black 16 | rev: 23.7.0 17 | hooks: 18 | - id: black 19 | args: ["--line-length=110"] 20 | 21 | - repo: https://github.com/pre-commit/pre-commit-hooks 22 | rev: v4.4.0 23 | hooks: 24 | - id: trailing-whitespace 25 | args: [--markdown-linebreak-ext=md] 26 | - id: end-of-file-fixer 27 | - id: check-yaml 28 | 
- id: check-added-large-files 29 | -------------------------------------------------------------------------------- /config/inference/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "unet_additional_kwargs": { 3 | "unet_use_cross_frame_attention": false, 4 | "unet_use_temporal_attention": false, 5 | "use_motion_module": true, 6 | "motion_module_resolutions": [1, 2, 4, 8], 7 | "motion_module_mid_block": false, 8 | "motion_module_decoder_only": false, 9 | "motion_module_type": "Vanilla", 10 | "motion_module_kwargs": { 11 | "num_attention_heads": 8, 12 | "num_transformer_block": 1, 13 | "attention_block_types": ["Temporal_Self", "Temporal_Self"], 14 | "temporal_position_encoding": true, 15 | "temporal_position_encoding_max_len": 24, 16 | "temporal_attention_dim_div": 1 17 | } 18 | }, 19 | "noise_scheduler_kwargs": { 20 | "num_train_timesteps": 1000, 21 | "beta_start": 0.00085, 22 | "beta_end": 0.012, 23 | "beta_schedule": "linear", 24 | "steps_offset": 1, 25 | "clip_sample": false 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /config/inference/motion_v2.json: -------------------------------------------------------------------------------- 1 | { 2 | "unet_additional_kwargs": { 3 | "use_inflated_groupnorm": true, 4 | "unet_use_cross_frame_attention": false, 5 | "unet_use_temporal_attention": false, 6 | "use_motion_module": true, 7 | "motion_module_resolutions": [1, 2, 4, 8], 8 | "motion_module_mid_block": true, 9 | "motion_module_decoder_only": false, 10 | "motion_module_type": "Vanilla", 11 | "motion_module_kwargs": { 12 | "num_attention_heads": 8, 13 | "num_transformer_block": 1, 14 | "attention_block_types": ["Temporal_Self", "Temporal_Self"], 15 | "temporal_position_encoding": true, 16 | "temporal_position_encoding_max_len": 32, 17 | "temporal_attention_dim_div": 1 18 | } 19 | }, 20 | "noise_scheduler_kwargs": { 21 | "num_train_timesteps": 1000, 22 | "beta_start": 0.00085, 23 | "beta_end": 0.012, 24 | "beta_schedule": "linear", 25 | "steps_offset": 1, 26 | "clip_sample": false 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /scripts/test_persistent.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | 3 | from animatediff import get_dir 4 | from animatediff.cli import generate, logger 5 | 6 | config_dir = get_dir("config") 7 | 8 | config_path = config_dir.joinpath("prompts/test.json") 9 | width = 512 10 | height = 512 11 | length = 32 12 | context = 16 13 | stride = 4 14 | 15 | logger.warn("Running first-round generation test, this should load the full model.\n\n") 16 | out_dir = generate( 17 | config_path=config_path, 18 | width=width, 19 | height=height, 20 | length=length, 21 | context=context, 22 | stride=stride, 23 | ) 24 | logger.warn(f"Generated animation to {out_dir}") 25 | 26 | logger.warn("\n\nRunning second-round generation test, this should reuse the already loaded model.\n\n") 27 | out_dir = generate( 28 | config_path=config_path, 29 | width=width, 30 | height=height, 31 | length=length, 32 | context=context, 33 | stride=stride, 34 | ) 35 | logger.warn(f"Generated animation to {out_dir}") 36 | 37 | logger.error("If the second round didn't talk about reloading the model, it worked! 
yay!") 38 | -------------------------------------------------------------------------------- /scripts/download/sd-models.aria2: -------------------------------------------------------------------------------- 1 | https://civitai.com/api/download/models/78775 2 | out=models/sd/toonyou_beta3.safetensors 3 | https://civitai.com/api/download/models/72396 4 | out=models/sd/lyriel_v16.safetensors 5 | https://civitai.com/api/download/models/71009 6 | out=models/sd/rcnzCartoon3d_v10.safetensors 7 | https://civitai.com/api/download/models/79068 8 | out=majicmixRealistic_v5Preview.safetensors 9 | https://civitai.com/api/download/models/29460 10 | out=models/sd/realisticVisionV40_v20Novae.safetensors 11 | https://civitai.com/api/download/models/97261 12 | out=models/sd/TUSUN.safetensors 13 | https://civitai.com/api/download/models/50705 14 | out=models/sd/leosamsMoonfilm_reality20.safetensors 15 | https://civitai.com/api/download/models/90115 16 | out=models/sd/FilmVelvia2.safetensors 17 | https://civitai.com/api/download/models/92475 18 | out=models/sd/leosamsMoonfilm_filmGrain10.safetensors 19 | https://civitai.com/api/download/models/102828 20 | out=models/sd/Pyramid\ lora_Ghibli_n3.safetensors 21 | https://civitai.com/api/download/models/57618 22 | out=models/sd/CounterfeitV30_v30.safetensors 23 | -------------------------------------------------------------------------------- /config/GroundingDINO/GroundingDINO_SwinB_cfg.py: -------------------------------------------------------------------------------- 1 | batch_size = 1 2 | modelname = "groundingdino" 3 | backbone = "swin_B_384_22k" 4 | position_embedding = "sine" 5 | pe_temperatureH = 20 6 | pe_temperatureW = 20 7 | return_interm_indices = [1, 2, 3] 8 | backbone_freeze_keywords = None 9 | enc_layers = 6 10 | dec_layers = 6 11 | pre_norm = False 12 | dim_feedforward = 2048 13 | hidden_dim = 256 14 | dropout = 0.0 15 | nheads = 8 16 | num_queries = 900 17 | query_dim = 4 18 | num_patterns = 0 19 | num_feature_levels = 4 20 | enc_n_points = 4 21 | dec_n_points = 4 22 | two_stage_type = "standard" 23 | two_stage_bbox_embed_share = False 24 | two_stage_class_embed_share = False 25 | transformer_activation = "relu" 26 | dec_pred_bbox_embed_share = True 27 | dn_box_noise_scale = 1.0 28 | dn_label_noise_ratio = 0.5 29 | dn_label_coef = 1.0 30 | dn_bbox_coef = 1.0 31 | embed_init_tgt = True 32 | dn_labelbook_size = 2000 33 | max_text_len = 256 34 | text_encoder_type = "bert-base-uncased" 35 | use_text_enhancer = True 36 | use_fusion_layer = True 37 | use_checkpoint = True 38 | use_transformer_ckpt = True 39 | use_text_cross_attention = True 40 | text_dropout = 0.0 41 | fusion_dropout = 0.0 42 | fusion_droppath = 0.1 43 | sub_sentence_present = True 44 | -------------------------------------------------------------------------------- /config/GroundingDINO/GroundingDINO_SwinT_OGC.py: -------------------------------------------------------------------------------- 1 | batch_size = 1 2 | modelname = "groundingdino" 3 | backbone = "swin_T_224_1k" 4 | position_embedding = "sine" 5 | pe_temperatureH = 20 6 | pe_temperatureW = 20 7 | return_interm_indices = [1, 2, 3] 8 | backbone_freeze_keywords = None 9 | enc_layers = 6 10 | dec_layers = 6 11 | pre_norm = False 12 | dim_feedforward = 2048 13 | hidden_dim = 256 14 | dropout = 0.0 15 | nheads = 8 16 | num_queries = 900 17 | query_dim = 4 18 | num_patterns = 0 19 | num_feature_levels = 4 20 | enc_n_points = 4 21 | dec_n_points = 4 22 | two_stage_type = "standard" 23 | two_stage_bbox_embed_share = False 24 | 
two_stage_class_embed_share = False 25 | transformer_activation = "relu" 26 | dec_pred_bbox_embed_share = True 27 | dn_box_noise_scale = 1.0 28 | dn_label_noise_ratio = 0.5 29 | dn_label_coef = 1.0 30 | dn_bbox_coef = 1.0 31 | embed_init_tgt = True 32 | dn_labelbook_size = 2000 33 | max_text_len = 256 34 | text_encoder_type = "bert-base-uncased" 35 | use_text_enhancer = True 36 | use_fusion_layer = True 37 | use_checkpoint = True 38 | use_transformer_ckpt = True 39 | use_text_cross_attention = True 40 | text_dropout = 0.0 41 | fusion_dropout = 0.0 42 | fusion_droppath = 0.1 43 | sub_sentence_present = True 44 | -------------------------------------------------------------------------------- /config/prompts/01-ToonYou.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ToonYou", 3 | "base": "", 4 | "path": "models/sd/toonyou_beta3.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "compile": false, 7 | "seed": [ 8 | 10788741199826055000, 6520604954829637000, 6519455744612556000, 9 | 16372571278361864000 10 | ], 11 | "scheduler": "k_dpmpp", 12 | "steps": 30, 13 | "guidance_scale": 8.5, 14 | "clip_skip": 2, 15 | "prompt": [ 16 | "1girl, solo, best quality, masterpiece, looking at viewer, purple hair, orange hair, gradient hair, blurry background, upper body, dress, flower print, spaghetti strap, bare shoulders", 17 | "1girl, solo, masterpiece, best quality, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,", 18 | "1girl, solo, best quality, masterpiece, looking at viewer, purple hair, orange hair, gradient hair, blurry background, upper body, dress, flower print, spaghetti strap, bare shoulders", 19 | "1girl, solo, best quality, masterpiece, cloudy sky, dandelion, contrapposto, alternate hairstyle" 20 | ], 21 | "n_prompt": [ 22 | "worst quality, low quality, cropped, lowres, text, jpeg artifacts, multiple view" 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /src/animatediff/utils/wild_card.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import re 5 | 6 | wild_card_regex = r'(\A|\W)__([\w-]+)__(\W|\Z)' 7 | 8 | 9 | def create_wild_card_map(wild_card_dir): 10 | result = {} 11 | if os.path.isdir(wild_card_dir): 12 | txt_list = glob.glob( os.path.join(wild_card_dir ,"**/*.txt"), recursive=True) 13 | for txt in txt_list: 14 | basename_without_ext = os.path.splitext(os.path.basename(txt))[0] 15 | with open(txt, encoding='utf-8') as f: 16 | try: 17 | result[basename_without_ext] = [s.rstrip() for s in f.readlines()] 18 | except Exception as e: 19 | print(e) 20 | print("can not read ", txt) 21 | return result 22 | 23 | def replace_wild_card_token(match_obj, wild_card_map): 24 | m1 = match_obj.group(1) 25 | m3 = match_obj.group(3) 26 | 27 | dict_name = match_obj.group(2) 28 | 29 | if dict_name in wild_card_map: 30 | token_list = wild_card_map[dict_name] 31 | token = token_list[random.randint(0,len(token_list)-1)] 32 | return m1+token+m3 33 | else: 34 | return match_obj.group(0) 35 | 36 | def replace_wild_card(prompt, wild_card_dir): 37 | wild_card_map = create_wild_card_map(wild_card_dir) 38 | prompt = re.sub(wild_card_regex, lambda x: replace_wild_card_token(x, wild_card_map ), prompt) 39 | return prompt 40 | 
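# Illustrative usage sketch, assuming a hypothetical wildcard directory such as
# data/wildcards containing e.g. haircolor.txt with one token per line: each
# "__haircolor__" in the prompt is replaced with a random line from that file,
# while unknown "__tokens__" (or a missing directory) are left untouched.
if __name__ == "__main__":
    sample_prompt = "1girl, solo, __haircolor__ hair, __background__"
    print(replace_wild_card(sample_prompt, "data/wildcards"))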
-------------------------------------------------------------------------------- /config/prompts/04-MajicMix.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MajicMix", 3 | "base": "", 4 | "path": "models/sd/majicmixRealistic_v5Preview.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 1572448948722921000, 1099474677988590700, 6488833139725636000, 8 | 18339859844376519000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 | "guidance_scale": 7.5, 13 | "prompt": [ 14 | "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic", 15 | "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting", 16 | "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below", 17 | "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" 18 | ], 19 | "n_prompt": [ 20 | "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles", 21 | "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome", 22 | "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome", 23 | "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /config/prompts/08-GhibliBackground.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "GhibliBackground", 3 | "base": "models/sd/CounterfeitV30_25.safetensors", 4 | "path": "models/sd/lora_Ghibli_n3.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 8775748474469046000, 5893874876080607000, 11911465742147697000, 8 | 12437784838692000000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 | "guidance_scale": 7.5, 13 | "lora_alpha": 1, 14 | "prompt": [ 15 | "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall", 16 | "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter", 17 | ",mysterious sea area, fantasy,build,concept", 18 | "Tomb Raider,Scenography,Old building" 19 | ], 20 | "n_prompt": [ 21 | "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality" 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /scripts/download/02-All-SD-Models.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | 4 | repo_dir=$(git rev-parse --show-toplevel 2>/dev/null || true)  # don't abort under set -e when not inside a git repo 5 | if [[ ! -d "${repo_dir}" ]]; then 6 | echo "Could not find the repo root. Checking for ./data/models/sd" 7 | repo_dir="." 8 | fi 9 | 10 | models_dir=$(realpath "${repo_dir}/data/models/sd") 11 | if [[ ! -d "${models_dir}" ]]; then 12 | echo "Could not find repo root or models directory." 13 | echo "Either create ./data/models/sd or run this script from a checked-out git repo." 14 | exit 1 15 | fi 16 | 17 | model_urls=( 18 | https://civitai.com/api/download/models/78775 # ToonYou 19 | https://civitai.com/api/download/models/72396 # Lyriel 20 | https://civitai.com/api/download/models/71009 # RcnzCartoon 21 | https://civitai.com/api/download/models/79068 # MajicMix 22 | https://civitai.com/api/download/models/29460 # RealisticVision 23 | https://civitai.com/api/download/models/97261 # Tusun (1/2) 24 | https://civitai.com/api/download/models/50705 # Tusun (2/2) 25 | https://civitai.com/api/download/models/90115 # FilmVelvia (1/2) 26 | https://civitai.com/api/download/models/92475 # FilmVelvia (2/2) 27 | https://civitai.com/api/download/models/102828 # GhibliBackground (1/2) 28 | https://civitai.com/api/download/models/57618 # GhibliBackground (2/2) 29 | ) 30 | 31 | echo "Downloading model files to ${models_dir}..." 32 | 33 | # Create the models directory if it doesn't exist 34 | mkdir -p "${models_dir}" 35 | 36 | # Download the models 37 | for url in "${model_urls[@]}"; do 38 | curl -JLO --output-dir "${models_dir}" "${url}" || true 39 | done 40 | -------------------------------------------------------------------------------- /src/animatediff/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | import cv2 3 | import numpy as np 4 | import onnxruntime as ort 5 | 6 | from .onnxdet import inference_detector 7 | from .onnxpose import inference_pose 8 | 9 | 10 | class Wholebody: 11 | def __init__(self, device='cuda:0'): 12 | providers = ['CPUExecutionProvider' 13 | ] if device == 'cpu' else ['CUDAExecutionProvider'] 14 | onnx_det = 'data/models/DWPose/yolox_l.onnx' 15 | onnx_pose = 'data/models/DWPose/dw-ll_ucoco_384.onnx' 16 | 17 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) 18 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) 19 | 20 | def __call__(self, oriImg): 21 | det_result = inference_detector(self.session_det, oriImg) 22 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 23 | 24 | keypoints_info = np.concatenate( 25 | (keypoints, scores[..., None]), axis=-1) 26 | # compute neck joint 27 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 28 | # neck score when visualizing pred 29 | neck[:, 2:4] = np.logical_and( 30 | keypoints_info[:, 5, 2:4] > 0.3, 31 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 32 | new_keypoints_info = np.insert( 33 | keypoints_info, 17, neck, axis=1) 34 | mmpose_idx = [ 35 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 36 | ] 37 | openpose_idx = [ 38 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 39 | ] 40 | new_keypoints_info[:, openpose_idx] = \ 41 | new_keypoints_info[:, mmpose_idx] 42 | keypoints_info = new_keypoints_info 43 | 44 | keypoints, scores = keypoints_info[ 45 | ..., :2], keypoints_info[..., 2] 46 | 47 | return keypoints, scores 48 | 49 | 50 |
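# Illustrative usage sketch, assuming the two DWPose ONNX files already sit at the
# hard-coded paths above and that a hypothetical frame image exists on disk; the
# call returns per-person keypoints in pixel coordinates plus confidence scores.
if __name__ == "__main__":
    estimator = Wholebody(device="cuda:0")
    frame = cv2.imread("input_frame.png")  # BGR image, as OpenCV loads it
    keypoints, scores = estimator(frame)
    print(keypoints.shape, scores.shape)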
-------------------------------------------------------------------------------- /src/animatediff/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import ( 3 | version as __version__, 4 | version_tuple, 5 | ) 6 | except ImportError: 7 | __version__ = "unknown (no version information available)" 8 | version_tuple = (0, 0, "unknown", "noinfo") 9 | 10 | from functools import lru_cache 11 | from os import getenv 12 | from pathlib import Path 13 | from warnings import filterwarnings 14 | 15 | from rich.console import Console 16 | from tqdm import TqdmExperimentalWarning 17 | 18 | PACKAGE = __package__.replace("_", "-") 19 | PACKAGE_ROOT = Path(__file__).parent.parent 20 | 21 | HF_HOME = Path(getenv("HF_HOME", Path.home() / ".cache" / "huggingface")) 22 | HF_HUB_CACHE = Path(getenv("HUGGINGFACE_HUB_CACHE", HF_HOME.joinpath("hub"))) 23 | 24 | HF_LIB_NAME = "animatediff-cli" 25 | HF_LIB_VER = __version__ 26 | HF_MODULE_REPO = "neggles/animatediff-modules" 27 | 28 | console = Console(highlight=True) 29 | err_console = Console(stderr=True) 30 | 31 | # shhh torch, don't worry about it it's fine 32 | filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") 33 | # you too tqdm 34 | filterwarnings("ignore", category=TqdmExperimentalWarning) 35 | 36 | 37 | @lru_cache(maxsize=4) 38 | def get_dir(dirname: str = "data") -> Path: 39 | if PACKAGE_ROOT.name == "src": 40 | # we're installed in editable mode from within the repo 41 | dirpath = PACKAGE_ROOT.parent.joinpath(dirname) 42 | else: 43 | # we're installed normally, so we just use the current working directory 44 | dirpath = Path.cwd().joinpath(dirname) 45 | dirpath.mkdir(parents=True, exist_ok=True) 46 | return dirpath.absolute() 47 | 48 | 49 | __all__ = [ 50 | "__version__", 51 | "version_tuple", 52 | "PACKAGE", 53 | "PACKAGE_ROOT", 54 | "HF_HOME", 55 | "HF_HUB_CACHE", 56 | "console", 57 | "err_console", 58 | "get_dir", 59 | "models", 60 | "pipelines", 61 | "rife", 62 | "utils", 63 | "cli", 64 | "generate", 65 | "schedulers", 66 | "settings", 67 | ] 68 | -------------------------------------------------------------------------------- /config/inference/sd15-unet3d.json: -------------------------------------------------------------------------------- 1 | { 2 | "sample_size": 64, 3 | "in_channels": 4, 4 | "out_channels": 4, 5 | "center_input_sample": false, 6 | "flip_sin_to_cos": true, 7 | "freq_shift": 0, 8 | "down_block_types": [ 9 | "CrossAttnDownBlock3D", 10 | "CrossAttnDownBlock3D", 11 | "CrossAttnDownBlock3D", 12 | "DownBlock3D" 13 | ], 14 | "mid_block_type": "UNetMidBlock3DCrossAttn", 15 | "up_block_types": [ 16 | "UpBlock3D", 17 | "CrossAttnUpBlock3D", 18 | "CrossAttnUpBlock3D", 19 | "CrossAttnUpBlock3D" 20 | ], 21 | "only_cross_attention": false, 22 | "block_out_channels": [320, 640, 1280, 1280], 23 | "layers_per_block": 2, 24 | "downsample_padding": 1, 25 | "mid_block_scale_factor": 1, 26 | "act_fn": "silu", 27 | "norm_num_groups": 32, 28 | "norm_eps": 1e-5, 29 | "cross_attention_dim": 768, 30 | "attention_head_dim": 8, 31 | "dual_cross_attention": false, 32 | "use_linear_projection": false, 33 | "class_embed_type": null, 34 | "num_class_embeds": null, 35 | "upcast_attention": false, 36 | "resnet_time_scale_shift": "default", 37 | "use_motion_module": true, 38 | "motion_module_resolutions": [1, 2, 4, 8], 39 | "motion_module_mid_block": false, 40 | "motion_module_decoder_only": false, 41 | "motion_module_type": "Vanilla", 42 | 
"motion_module_kwargs": { 43 | "num_attention_heads": 8, 44 | "num_transformer_block": 1, 45 | "attention_block_types": ["Temporal_Self", "Temporal_Self"], 46 | "temporal_position_encoding": true, 47 | "temporal_position_encoding_max_len": 24, 48 | "temporal_attention_dim_div": 1 49 | }, 50 | "unet_use_cross_frame_attention": false, 51 | "unet_use_temporal_attention": false, 52 | "_use_default_values": [ 53 | "use_linear_projection", 54 | "mid_block_type", 55 | "upcast_attention", 56 | "dual_cross_attention", 57 | "num_class_embeds", 58 | "only_cross_attention", 59 | "class_embed_type", 60 | "resnet_time_scale_shift" 61 | ], 62 | "_class_name": "UNet3DConditionModel", 63 | "_diffusers_version": "0.6.0" 64 | } 65 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = animatediff 3 | author = Andi Powers-Holmes 4 | email = aholmes@omnom.net 5 | maintainer = Andi Powers-Holmes 6 | maintainer_email = aholmes@omnom.net 7 | license_files = LICENSE.md 8 | 9 | [options] 10 | python_requires = >=3.10 11 | packages = find: 12 | package_dir = 13 | =src 14 | py_modules = 15 | animatediff 16 | include_package_data = True 17 | install_requires = 18 | accelerate >= 0.20.3 19 | colorama >= 0.4.3, < 0.5.0 20 | cmake >= 3.25.0 21 | diffusers == 0.23.0 22 | einops >= 0.6.1 23 | gdown >= 4.6.6 24 | ninja >= 1.11.0 25 | numpy >= 1.22.4 26 | omegaconf >= 2.3.0 27 | pillow >= 9.4.0, < 10.0.0 28 | pydantic >= 1.10.0, < 2.0.0 29 | rich >= 13.0.0, < 14.0.0 30 | safetensors >= 0.3.1 31 | sentencepiece >= 0.1.99 32 | shellingham >= 1.5.0, < 2.0.0 33 | torch >= 2.1.0, < 2.2.0 34 | torchaudio 35 | torchvision 36 | transformers >= 4.30.2, < 4.35.0 37 | typer >= 0.9.0, < 1.0.0 38 | controlnet_aux 39 | matplotlib 40 | ffmpeg-python >= 0.2.0 41 | mediapipe 42 | xformers >= 0.0.22.post7 43 | opencv-python 44 | 45 | [options.packages.find] 46 | where = src 47 | 48 | [options.package_data] 49 | * = *.txt, *.md 50 | 51 | [options.extras_require] 52 | dev = 53 | black >= 22.3.0 54 | ruff >= 0.0.234 55 | setuptools-scm >= 7.0.0 56 | pre-commit >= 3.3.0 57 | ipython 58 | rife = 59 | ffmpeg-python >= 0.2.0 60 | stylize = 61 | ffmpeg-python >= 0.2.0 62 | onnxruntime-gpu 63 | pandas 64 | opencv-python 65 | dwpose = 66 | onnxruntime-gpu 67 | stylize_mask = 68 | ffmpeg-python >= 0.2.0 69 | pandas 70 | segment-anything-hq == 0.3 71 | groundingdino-py == 0.4.0 72 | gitpython 73 | rembg[gpu] 74 | onnxruntime-gpu 75 | 76 | [options.entry_points] 77 | console_scripts = 78 | animatediff = animatediff.cli:cli 79 | 80 | [flake8] 81 | max-line-length = 110 82 | ignore = 83 | # these are annoying during development but should be enabled later 84 | F401 # module imported but unused 85 | F841 # local variable is assigned to but never used 86 | # black automatically fixes this 87 | E501 # line too long 88 | # black breaks these two rules: 89 | E203 # whitespace before : 90 | W503 # line break before binary operator 91 | extend-exclude = 92 | .venv 93 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.insertSpaces": true, 3 | "editor.tabSize": 4, 4 | "files.trimTrailingWhitespace": true, 5 | "editor.rulers": [100, 120], 6 | 7 | "files.associations": { 8 | "*.yaml": "yaml" 9 | }, 10 | 11 | "files.exclude": { 12 | "**/.git": true, 13 | "**/.svn": true, 
14 | "**/.hg": true, 15 | "**/CVS": true, 16 | "**/.DS_Store": true, 17 | "**/Thumbs.db": true, 18 | "**/__pycache__": true 19 | }, 20 | 21 | "[python]": { 22 | "editor.wordBasedSuggestions": false, 23 | "editor.formatOnSave": true, 24 | "editor.defaultFormatter": "ms-python.black-formatter", 25 | "editor.codeActionsOnSave": { 26 | "source.organizeImports": true 27 | } 28 | }, 29 | "python.analysis.include": ["./src", "./scripts", "./tests"], 30 | 31 | "python.linting.enabled": false, 32 | "python.linting.pylintEnabled": false, 33 | "python.linting.flake8Enabled": true, 34 | "python.linting.flake8Args": ["--config=${workspaceFolder}/setup.cfg"], 35 | 36 | "[json]": { 37 | "editor.tabSize": 2, 38 | "editor.detectIndentation": false, 39 | "editor.formatOnSave": true, 40 | "editor.formatOnSaveMode": "file" 41 | }, 42 | 43 | "[toml]": { 44 | "editor.tabSize": 2, 45 | "editor.detectIndentation": false, 46 | "editor.formatOnSave": true, 47 | "editor.formatOnSaveMode": "file", 48 | "editor.defaultFormatter": "tamasfe.even-better-toml", 49 | "editor.rulers": [80, 100] 50 | }, 51 | "evenBetterToml.formatter.columnWidth": 88, 52 | 53 | "[yaml]": { 54 | "editor.detectIndentation": false, 55 | "editor.tabSize": 2, 56 | "editor.formatOnSave": true, 57 | "editor.formatOnSaveMode": "file" 58 | }, 59 | "yaml.format.bracketSpacing": true, 60 | "yaml.format.proseWrap": "preserve", 61 | "yaml.format.singleQuote": false, 62 | "yaml.format.printWidth": 110, 63 | 64 | "[markdown]": { 65 | "files.trimTrailingWhitespace": false 66 | }, 67 | 68 | "css.lint.validProperties": ["dock", "content-align", "content-justify"], 69 | "[css]": { 70 | "editor.formatOnSave": true 71 | }, 72 | 73 | "remote.autoForwardPorts": false, 74 | "remote.autoForwardPortsSource": "process" 75 | } 76 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import os 4 | import asyncio 5 | 6 | async def stylize(video): 7 | command = f"animatediff stylize create-config {video}" 8 | process = await asyncio.create_subprocess_shell( 9 | command, 10 | stdout=asyncio.subprocess.PIPE, 11 | stderr=asyncio.subprocess.PIPE 12 | ) 13 | stdout, stderr = await process.communicate() 14 | if process.returncode == 0: 15 | return stdout.decode() 16 | else: 17 | print(f"Error: {stderr.decode()}") 18 | 19 | async def start_video_edit(prompt_file): 20 | command = f"animatediff stylize generate {prompt_file}" 21 | process = await asyncio.create_subprocess_shell( 22 | command, 23 | stdout=asyncio.subprocess.PIPE, 24 | stderr=asyncio.subprocess.PIPE 25 | ) 26 | stdout, stderr = await process.communicate() 27 | if process.returncode == 0: 28 | return stdout.decode() 29 | else: 30 | print(f"Error: {stderr.decode()}") 31 | 32 | def edit_video(video, pos_prompt): 33 | x = asyncio.run(stylize(video)) 34 | x = x.split("stylize.py") 35 | config = x[18].split("config =")[-1].strip() 36 | d = x[19].split("stylize_dir = ")[-1].strip() 37 | 38 | with open(config, 'r+') as f: 39 | data = json.load(f) 40 | data['head_prompt'] = pos_prompt 41 | data["path"] = "models/huggingface/xxmix9realistic_v40.safetensors" 42 | 43 | os.remove(config) 44 | with open(config, 'w') as f: 45 | json.dump(data, f, indent=4) 46 | 47 | out = asyncio.run(start_video_edit(d)) 48 | out = out.split("Stylized results are output to ")[-1] 49 | out = out.split("stylize.py")[0].strip() 50 | 51 | cwd = os.getcwd() 52 | video_dir = cwd + "/" + out 53 | 54 | 
video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.flv', '.wmv'} 55 | video_path = None 56 | 57 | for dirpath, dirnames, filenames in os.walk(video_dir): 58 | for filename in filenames: 59 | if os.path.splitext(filename)[1].lower() in video_extensions: 60 | video_path = os.path.join(dirpath, filename) 61 | break 62 | if video_path: 63 | break 64 | 65 | return video_path 66 | 67 | video_path = input("Enter the path to your video: ") 68 | pos_prompt = input("Enter the what you want to do with the video: ") 69 | print("The video is stored at", edit_video(video_path, pos_prompt)) -------------------------------------------------------------------------------- /src/animatediff/utils/mask_rembg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from rembg import new_session, remove 11 | from tqdm.rich import tqdm 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | def rembg_create_fg(frame_dir, output_dir, output_mask_dir, masked_area_list, 16 | bg_color=(0,255,0), 17 | mask_padding=0, 18 | ): 19 | 20 | frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False)) 21 | 22 | if mask_padding != 0: 23 | kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8) 24 | kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) 25 | 26 | session = new_session(providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 27 | 28 | for i, frame in tqdm(enumerate(frame_list),total=len(frame_list), desc=f"creating mask"): 29 | frame = Path(frame) 30 | file_name = frame.name 31 | 32 | cur_frame_no = int(frame.stem) 33 | 34 | img = Image.open(frame) 35 | img_array = np.asarray(img) 36 | 37 | mask_array = remove(img_array, only_mask=True, session=session) 38 | 39 | #mask_array = mask_array[None,...] 40 | 41 | if mask_padding < 0: 42 | mask_array = cv2.erode(mask_array.astype(np.uint8),kernel,iterations = 1) 43 | elif mask_padding > 0: 44 | mask_array = cv2.dilate(mask_array.astype(np.uint8),kernel,iterations = 1) 45 | 46 | mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, kernel2) 47 | mask_array = cv2.GaussianBlur(mask_array, (7, 7), sigmaX=3, sigmaY=3, borderType=cv2.BORDER_DEFAULT) 48 | 49 | if masked_area_list[cur_frame_no] is not None: 50 | masked_area_list[cur_frame_no] = np.where(masked_area_list[cur_frame_no] > mask_array[None,...], masked_area_list[cur_frame_no], mask_array[None,...]) 51 | else: 52 | masked_area_list[cur_frame_no] = mask_array[None,...] 
53 | 54 | if output_mask_dir: 55 | Image.fromarray(mask_array).save( output_mask_dir / file_name ) 56 | 57 | img_array = np.asarray(img).copy() 58 | if bg_color is not None: 59 | img_array[mask_array == 0] = bg_color 60 | 61 | img = Image.fromarray(img_array) 62 | 63 | img.save( output_dir / file_name ) 64 | 65 | return masked_area_list 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /config/prompts/03-RcnzCartoon.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RcnzCartoon", 3 | "base": "", 4 | "path": "models/sd/rcnzCartoon3d_v10.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 16931037867122268000, 2094308009433392000, 4292543217695451000, 8 | 15572665120852310000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 | "guidance_scale": 7.5, 13 | "prompt": [ 14 | "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded", 15 | "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face", 16 | "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes", 17 | "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" 18 | ], 19 | "n_prompt": [ 20 | "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation", 21 | "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular", 22 | "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,", 23 | "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /config/prompts/06-Tusun.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tusun", 3 | "base": "models/sd/moonfilm_reality20.safetensors", 4 | "path": "models/sd/TUSUN.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 10154078483724687000, 2664393535095473700, 4231566096207623000, 8 | 1713349740448094500 9 | ], 10 | "scheduler": "k_dpmpp", 11 | 
"steps": 25, 12 | "guidance_scale": 7.5, 13 | "lora_alpha": 0.6, 14 | "prompt": [ 15 | "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing", 16 | "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing", 17 | "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing", 18 | "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" 19 | ], 20 | "n_prompt": [ 21 | "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /config/prompts/05-RealisticVision.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RealisticVision", 3 | "base": "", 4 | "path": "models/sd/realisticVisionV20_v20.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 5658137986800322000, 12099779162349365000, 10499524853910854000, 8 | 16768009035333712000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 | "guidance_scale": 7.5, 13 | "prompt": [ 14 | "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", 15 | "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot", 16 | "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3", 17 | "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" 18 | ], 19 | "n_prompt": [ 20 | "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, 
cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 21 | "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 22 | "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation", 23 | "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /config/inference/sd15-unet.json: -------------------------------------------------------------------------------- 1 | { 2 | "sample_size": 64, 3 | "in_channels": 4, 4 | "out_channels": 4, 5 | "center_input_sample": false, 6 | "flip_sin_to_cos": true, 7 | "freq_shift": 0, 8 | "down_block_types": [ 9 | "CrossAttnDownBlock2D", 10 | "CrossAttnDownBlock2D", 11 | "CrossAttnDownBlock2D", 12 | "DownBlock2D" 13 | ], 14 | "mid_block_type": "UNetMidBlock2DCrossAttn", 15 | "up_block_types": [ 16 | "UpBlock2D", 17 | "CrossAttnUpBlock2D", 18 | "CrossAttnUpBlock2D", 19 | "CrossAttnUpBlock2D" 20 | ], 21 | "only_cross_attention": false, 22 | "block_out_channels": [320, 640, 1280, 1280], 23 | "layers_per_block": 2, 24 | "downsample_padding": 1, 25 | "mid_block_scale_factor": 1, 26 | "act_fn": "silu", 27 | "norm_num_groups": 32, 28 | "norm_eps": 1e-5, 29 | "cross_attention_dim": 768, 30 | "transformer_layers_per_block": 1, 31 | "encoder_hid_dim": null, 32 | "encoder_hid_dim_type": null, 33 | "attention_head_dim": 8, 34 | "num_attention_heads": null, 35 | "dual_cross_attention": false, 36 | "use_linear_projection": false, 37 | "class_embed_type": null, 38 | "addition_embed_type": null, 39 | "addition_time_embed_dim": null, 40 | "num_class_embeds": null, 41 | "upcast_attention": false, 42 | "resnet_time_scale_shift": "default", 43 | "resnet_skip_time_act": false, 44 | "resnet_out_scale_factor": 1.0, 45 | "time_embedding_type": "positional", 46 | "time_embedding_dim": null, 47 | "time_embedding_act_fn": null, 48 | "timestep_post_act": null, 49 | "time_cond_proj_dim": null, 50 | "conv_in_kernel": 3, 51 | "conv_out_kernel": 3, 52 | "projection_class_embeddings_input_dim": null, 53 | "class_embeddings_concat": false, 54 | "mid_block_only_cross_attention": null, 55 | "cross_attention_norm": null, 56 | "addition_embed_type_num_heads": 64, 57 | "_use_default_values": [ 
58 | "transformer_layers_per_block", 59 | "use_linear_projection", 60 | "num_class_embeds", 61 | "addition_embed_type", 62 | "cross_attention_norm", 63 | "conv_out_kernel", 64 | "encoder_hid_dim_type", 65 | "projection_class_embeddings_input_dim", 66 | "num_attention_heads", 67 | "only_cross_attention", 68 | "class_embed_type", 69 | "resnet_time_scale_shift", 70 | "addition_embed_type_num_heads", 71 | "timestep_post_act", 72 | "mid_block_type", 73 | "mid_block_only_cross_attention", 74 | "time_embedding_type", 75 | "addition_time_embed_dim", 76 | "time_embedding_dim", 77 | "encoder_hid_dim", 78 | "resnet_skip_time_act", 79 | "conv_in_kernel", 80 | "upcast_attention", 81 | "dual_cross_attention", 82 | "resnet_out_scale_factor", 83 | "time_cond_proj_dim", 84 | "class_embeddings_concat", 85 | "time_embedding_act_fn" 86 | ], 87 | "_class_name": "UNet2DConditionModel", 88 | "_diffusers_version": "0.6.0" 89 | } 90 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import asyncio 4 | import gradio as gr 5 | 6 | async def stylize(video): 7 | command = f"animatediff stylize create-config {video}" 8 | process = await asyncio.create_subprocess_shell( 9 | command, 10 | stdout=asyncio.subprocess.PIPE, 11 | stderr=asyncio.subprocess.PIPE 12 | ) 13 | stdout, stderr = await process.communicate() 14 | if process.returncode == 0: 15 | return stdout.decode() 16 | else: 17 | print(f"Error: {stderr.decode()}") 18 | 19 | async def start_video_edit(prompt_file): 20 | command = f"animatediff stylize generate {prompt_file}" 21 | process = await asyncio.create_subprocess_shell( 22 | command, 23 | stdout=asyncio.subprocess.PIPE, 24 | stderr=asyncio.subprocess.PIPE 25 | ) 26 | stdout, stderr = await process.communicate() 27 | if process.returncode == 0: 28 | return stdout.decode() 29 | else: 30 | print(f"Error: {stderr.decode()}") 31 | 32 | def edit_video(video, pos_prompt): 33 | x = asyncio.run(stylize(video)) 34 | x = x.split("stylize.py") 35 | config = x[18].split("config =")[-1].strip() 36 | d = x[19].split("stylize_dir = ")[-1].strip() 37 | 38 | with open(config, 'r+') as f: 39 | data = json.load(f) 40 | data['head_prompt'] = pos_prompt 41 | data["path"] = "models/huggingface/xxmix9realistic_v40.safetensors" 42 | 43 | os.remove(config) 44 | with open(config, 'w') as f: 45 | json.dump(data, f, indent=4) 46 | 47 | out = asyncio.run(start_video_edit(d)) 48 | out = out.split("Stylized results are output to ")[-1] 49 | out = out.split("stylize.py")[0].strip() 50 | 51 | cwd = os.getcwd() 52 | video_dir = cwd + "/" + out 53 | 54 | video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.flv', '.wmv'} 55 | video_path = None 56 | 57 | for dirpath, dirnames, filenames in os.walk(video_dir): 58 | for filename in filenames: 59 | if os.path.splitext(filename)[1].lower() in video_extensions: 60 | video_path = os.path.join(dirpath, filename) 61 | break 62 | if video_path: 63 | break 64 | 65 | return video_path 66 | 67 | print("ready") 68 | 69 | with gr.Blocks() as interface: 70 | gr.Markdown("## Video Processor with Text Prompts") 71 | with gr.Row(): 72 | with gr.Column(): 73 | positive_prompt = gr.Textbox(label="Positive Prompt") 74 | video_input = gr.Video(label="Input Video") 75 | with gr.Column(): 76 | video_output = gr.Video(label="Processed Video") 77 | 78 | process_button = gr.Button("Process Video") 79 | process_button.click(fn=edit_video, 80 | inputs=[video_input, 
positive_prompt], 81 | outputs=video_output 82 | ) 83 | 84 | interface.launch() 85 | -------------------------------------------------------------------------------- /src/animatediff/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | # Openpose 3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 5 | # 3rd Edited by ControlNet 6 | # 4th Edited by ControlNet (added face and correct hands) 7 | 8 | import os 9 | 10 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 11 | import cv2 12 | import numpy as np 13 | import torch 14 | from controlnet_aux.util import HWC3, resize_image 15 | from PIL import Image 16 | 17 | from . import util 18 | from .wholebody import Wholebody 19 | 20 | 21 | def draw_pose(pose, H, W): 22 | bodies = pose['bodies'] 23 | faces = pose['faces'] 24 | hands = pose['hands'] 25 | candidate = bodies['candidate'] 26 | subset = bodies['subset'] 27 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 28 | 29 | canvas = util.draw_bodypose(canvas, candidate, subset) 30 | 31 | canvas = util.draw_handpose(canvas, hands) 32 | 33 | canvas = util.draw_facepose(canvas, faces) 34 | 35 | return canvas 36 | 37 | 38 | class DWposeDetector: 39 | def __init__(self): 40 | pass 41 | 42 | def to(self, device): 43 | self.pose_estimation = Wholebody(device) 44 | return self 45 | 46 | def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): 47 | input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) 48 | 49 | input_image = HWC3(input_image) 50 | input_image = resize_image(input_image, detect_resolution) 51 | H, W, C = input_image.shape 52 | with torch.no_grad(): 53 | candidate, subset = self.pose_estimation(input_image) 54 | nums, keys, locs = candidate.shape 55 | candidate[..., 0] /= float(W) 56 | candidate[..., 1] /= float(H) 57 | body = candidate[:,:18].copy() 58 | body = body.reshape(nums*18, locs) 59 | score = subset[:,:18] 60 | for i in range(len(score)): 61 | for j in range(len(score[i])): 62 | if score[i][j] > 0.3: 63 | score[i][j] = int(18*i+j) 64 | else: 65 | score[i][j] = -1 66 | 67 | un_visible = subset<0.3 68 | candidate[un_visible] = -1 69 | 70 | foot = candidate[:,18:24] 71 | 72 | faces = candidate[:,24:92] 73 | 74 | hands = candidate[:,92:113] 75 | hands = np.vstack([hands, candidate[:,113:]]) 76 | 77 | bodies = dict(candidate=body, subset=score) 78 | pose = dict(bodies=bodies, hands=hands, faces=faces) 79 | 80 | detected_map = draw_pose(pose, H, W) 81 | detected_map = HWC3(detected_map) 82 | 83 | img = resize_image(input_image, image_resolution) 84 | H, W, C = img.shape 85 | 86 | detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) 87 | 88 | if output_type == "pil": 89 | detected_map = Image.fromarray(detected_map) 90 | 91 | return detected_map 92 | -------------------------------------------------------------------------------- /config/prompts/07-FilmVelvia.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "FilmVelvia", 3 | "base": "models/sd/majicmixRealistic_v4.safetensors", 4 | "path": "models/sd/FilmVelvia2.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 358675358833372800, 3519455280971924000, 11684545350557985000, 8 | 8696855302100400000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 
| "guidance_scale": 7.5, 13 | "lora_alpha": 0.6, 14 | "prompt": [ 15 | "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name", 16 | ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir", 17 | "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark", 18 | "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " 19 | ], 20 | "n_prompt": [ 21 | "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg", 22 | "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg", 23 | "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg", 24 | "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /config/prompts/02-Lyriel.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Lyriel", 3 | "base": "", 4 | "path": "models/sd/lyriel_v16.safetensors", 5 | "motion_module": "models/motion-module/mm_sd_v15.ckpt", 6 | "seed": [ 7 | 10917152860782582000, 6399018107401806000, 15875751942533906000, 8 | 6653196880059937000 9 | ], 10 | "scheduler": "k_dpmpp", 11 | "steps": 25, 12 | "guidance_scale": 7.5, 13 | "prompt": [ 14 | "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, 
white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange", 15 | "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal", 16 | "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray", 17 | "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." 18 | ], 19 | "n_prompt": [ 20 | "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration", 21 | "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular", 22 | "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome", 23 | "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /src/animatediff/utils/mask_animseg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | import onnxruntime as rt 9 | import torch 10 | from PIL import Image 11 | from rembg import new_session, remove 12 | from tqdm.rich import tqdm 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | def animseg_create_fg(frame_dir, output_dir, 
output_mask_dir, masked_area_list, 17 | bg_color=(0,255,0), 18 | mask_padding=0, 19 | ): 20 | 21 | frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False)) 22 | 23 | if mask_padding != 0: 24 | kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8) 25 | kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) 26 | 27 | 28 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 29 | rmbg_model = rt.InferenceSession("data/models/anime_seg/isnetis.onnx", providers=providers) 30 | 31 | def get_mask(img, s=1024): 32 | img = (img / 255).astype(np.float32) 33 | h, w = h0, w0 = img.shape[:-1] 34 | h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s) 35 | ph, pw = s - h, s - w 36 | img_input = np.zeros([s, s, 3], dtype=np.float32) 37 | img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(img, (w, h)) 38 | img_input = np.transpose(img_input, (2, 0, 1)) 39 | img_input = img_input[np.newaxis, :] 40 | mask = rmbg_model.run(None, {'img': img_input})[0][0] 41 | mask = np.transpose(mask, (1, 2, 0)) 42 | mask = mask[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] 43 | mask = cv2.resize(mask, (w0, h0)) 44 | mask = (mask * 255).astype(np.uint8) 45 | return mask 46 | 47 | 48 | for i, frame in tqdm(enumerate(frame_list),total=len(frame_list), desc=f"creating mask"): 49 | frame = Path(frame) 50 | file_name = frame.name 51 | 52 | cur_frame_no = int(frame.stem) 53 | 54 | img = Image.open(frame) 55 | img_array = np.asarray(img) 56 | 57 | mask_array = get_mask(img_array) 58 | 59 | # Image.fromarray(mask_array).save( output_dir / Path("raw_" + file_name)) 60 | 61 | if mask_padding < 0: 62 | mask_array = cv2.erode(mask_array.astype(np.uint8),kernel,iterations = 1) 63 | elif mask_padding > 0: 64 | mask_array = cv2.dilate(mask_array.astype(np.uint8),kernel,iterations = 1) 65 | 66 | mask_array = cv2.morphologyEx(mask_array, cv2.MORPH_OPEN, kernel2) 67 | mask_array = cv2.GaussianBlur(mask_array, (7, 7), sigmaX=3, sigmaY=3, borderType=cv2.BORDER_DEFAULT) 68 | 69 | if masked_area_list[cur_frame_no] is not None: 70 | masked_area_list[cur_frame_no] = np.where(masked_area_list[cur_frame_no] > mask_array[None,...], masked_area_list[cur_frame_no], mask_array[None,...]) 71 | else: 72 | masked_area_list[cur_frame_no] = mask_array[None,...] 
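# optionally write the raw mask image, then paint everything outside the mask with bg_color and save the composited frame to output_dir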
73 | 74 | if output_mask_dir: 75 | Image.fromarray(mask_array).save( output_mask_dir / file_name ) 76 | 77 | img_array = np.asarray(img).copy() 78 | if bg_color is not None: 79 | img_array[mask_array == 0] = bg_color 80 | 81 | img = Image.fromarray(img_array) 82 | 83 | img.save( output_dir / file_name ) 84 | 85 | return masked_area_list 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.30.1 3 | aiofiles==23.2.1 4 | aiohttp==3.9.5 5 | aiosignal==1.3.1 6 | altair==5.3.0 7 | analytics-python==1.4.post1 8 | -e git+https://github.com/TheNetherWatcher/Vid2Vid-using-Text-prompt.git@213aa1fc330895557d619514bc778d0985a3112a#egg=animatediff 9 | annotated-types==0.7.0 10 | antlr4-python3-runtime==4.9.3 11 | anyio==4.4.0 12 | async-timeout==4.0.3 13 | attrs==23.2.0 14 | backoff==1.10.0 15 | bcrypt==4.1.3 16 | beautifulsoup4==4.12.3 17 | blinker==1.8.2 18 | certifi==2024.2.2 19 | cffi==1.16.0 20 | charset-normalizer==3.3.2 21 | click==8.1.7 22 | cmake==3.29.3 23 | colorama==0.4.6 24 | coloredlogs==15.0.1 25 | contourpy==1.2.1 26 | controlnet-aux==0.0.9 27 | cryptography==42.0.7 28 | cycler==0.12.1 29 | diffusers==0.23.0 30 | dnspython==2.6.1 31 | einops==0.8.0 32 | email_validator==2.1.1 33 | exceptiongroup==1.2.1 34 | fastapi==0.111.0 35 | fastapi-cli==0.0.4 36 | ffmpeg-python==0.2.0 37 | ffmpy==0.3.2 38 | filelock==3.14.0 39 | Flask==3.0.3 40 | Flask-CacheBuster==1.0.0 41 | Flask-Cors==4.0.1 42 | Flask-Login==0.6.3 43 | flatbuffers==24.3.25 44 | fonttools==4.53.0 45 | frozenlist==1.4.1 46 | fsspec==2024.5.0 47 | future==1.0.0 48 | gdown==5.2.0 49 | gradio==3.0 50 | gradio_client==0.7.0 51 | h11==0.14.0 52 | httpcore==1.0.5 53 | httptools==0.6.1 54 | httpx==0.27.0 55 | huggingface-hub==0.17.3 56 | humanfriendly==10.0 57 | idna==3.7 58 | imageio==2.34.1 59 | importlib_metadata==7.1.0 60 | importlib_resources==6.4.0 61 | itsdangerous==2.2.0 62 | jax==0.4.28 63 | jaxlib==0.4.28 64 | Jinja2==3.1.4 65 | jsonschema==4.22.0 66 | jsonschema-specifications==2023.12.1 67 | kiwisolver==1.4.5 68 | lazy_loader==0.4 69 | linkify-it-py==2.0.3 70 | markdown-it-py==3.0.0 71 | markdown2==2.4.13 72 | MarkupSafe==2.1.5 73 | matplotlib==3.9.0 74 | mdit-py-plugins==0.4.1 75 | mdurl==0.1.2 76 | mediapipe==0.10.14 77 | ml-dtypes==0.4.0 78 | monotonic==1.6 79 | mpmath==1.3.0 80 | multidict==6.0.5 81 | networkx==3.3 82 | ninja==1.11.1.1 83 | numpy==1.26.4 84 | omegaconf==2.3.0 85 | opencv-contrib-python==4.9.0.80 86 | opencv-python==4.9.0.80 87 | opencv-python-headless==4.9.0.80 88 | opt-einsum==3.3.0 89 | orjson==3.10.3 90 | packaging==24.0 91 | pandas==2.2.2 92 | paramiko==3.4.0 93 | Pillow==9.5.0 94 | protobuf==4.25.3 95 | psutil==5.9.8 96 | pycparser==2.22 97 | pycryptodome==3.20.0 98 | pydantic==1.10.15 99 | pydantic_core==2.18.3 100 | pydub==0.25.1 101 | Pygments==2.18.0 102 | PyNaCl==1.5.0 103 | pyparsing==3.1.2 104 | PySocks==1.7.1 105 | python-dateutil==2.9.0.post0 106 | python-dotenv==1.0.1 107 | python-multipart==0.0.9 108 | pytz==2024.1 109 | PyYAML==6.0.1 110 | referencing==0.35.1 111 | regex==2024.5.15 112 | requests==2.32.3 113 | rich==13.7.1 114 | rpds-py==0.18.1 115 | ruff==0.4.7 116 | safetensors==0.4.3 117 | scikit-image==0.23.2 118 | scipy==1.13.1 119 | semantic-version==2.10.0 120 | sentencepiece==0.2.0 121 | shellingham==1.5.4 122 | six==1.16.0 123 | sniffio==1.3.1 124 | sounddevice==0.4.7 125 | soupsieve==2.5 126 | 
starlette==0.37.2 127 | sympy==1.12.1 128 | tifffile==2024.5.22 129 | timm==0.6.7 130 | tokenizers==0.14.1 131 | tomlkit==0.12.0 132 | toolz==0.12.1 133 | torch==2.1.2 134 | torchaudio==2.1.2 135 | torchvision==0.16.2 136 | tqdm==4.66.4 137 | transformers==4.34.1 138 | triton==2.1.0 139 | typer==0.12.3 140 | typing_extensions==4.12.0 141 | tzdata==2024.1 142 | uc-micro-py==1.0.3 143 | ujson==5.10.0 144 | urllib3==2.2.1 145 | uvicorn==0.30.1 146 | uvloop==0.19.0 147 | watchfiles==0.22.0 148 | websockets==11.0.3 149 | Werkzeug==3.0.3 150 | xformers==0.0.23.post1 151 | yarl==1.9.4 152 | zipp==3.19.1 -------------------------------------------------------------------------------- /src/animatediff/pipelines/context.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import numpy as np 4 | 5 | 6 | # Whatever this is, it's utterly cursed. 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | # I have absolutely no idea how this works and I don't like that. 16 | def uniform( 17 | step: int = ..., 18 | num_steps: Optional[int] = None, 19 | num_frames: int = ..., 20 | context_size: Optional[int] = None, 21 | context_stride: int = 3, 22 | context_overlap: int = 4, 23 | closed_loop: bool = True, 24 | ): 25 | if num_frames <= context_size: 26 | yield list(range(num_frames)) 27 | return 28 | 29 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 30 | 31 | for context_step in 1 << np.arange(context_stride): 32 | pad = int(round(num_frames * ordered_halving(step))) 33 | for j in range( 34 | int(ordered_halving(step) * context_step) + pad, 35 | num_frames + pad + (0 if closed_loop else -context_overlap), 36 | (context_size * context_step - context_overlap), 37 | ): 38 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] 39 | 40 | 41 | def shuffle( 42 | step: int = ..., 43 | num_steps: Optional[int] = None, 44 | num_frames: int = ..., 45 | context_size: Optional[int] = None, 46 | context_stride: int = 3, 47 | context_overlap: int = 4, 48 | closed_loop: bool = True, 49 | ): 50 | import random 51 | c = list(range(num_frames)) 52 | c = random.sample(c, len(c)) 53 | 54 | if len(c) % context_size: 55 | c += c[0:context_size - len(c) % context_size] 56 | 57 | c = random.sample(c, len(c)) 58 | 59 | for i in range(0, len(c), context_size): 60 | yield c[i:i+context_size] 61 | 62 | 63 | def composite( 64 | step: int = ..., 65 | num_steps: Optional[int] = None, 66 | num_frames: int = ..., 67 | context_size: Optional[int] = None, 68 | context_stride: int = 3, 69 | context_overlap: int = 4, 70 | closed_loop: bool = True, 71 | ): 72 | if (step/num_steps) < 0.1: 73 | return shuffle(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop) 74 | else: 75 | return uniform(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop) 76 | 77 | 78 | def get_context_scheduler(name: str) -> Callable: 79 | match name: 80 | case "uniform": 81 | return uniform 82 | case "shuffle": 83 | return shuffle 84 | case "composite": 85 | return composite 86 | case _: 87 | raise ValueError(f"Unknown context_overlap policy {name}") 88 | 89 | 90 | def get_total_steps( 91 | scheduler, 92 | timesteps: list[int], 93 | num_steps: Optional[int] = None, 94 | num_frames: int = ..., 95 | context_size: Optional[int] = None, 96 | 
context_stride: int = 3, 97 | context_overlap: int = 4, 98 | closed_loop: bool = True, 99 | ): 100 | return sum( 101 | len( 102 | list( 103 | scheduler( 104 | i, 105 | num_steps, 106 | num_frames, 107 | context_size, 108 | context_stride, 109 | context_overlap, 110 | ) 111 | ) 112 | ) 113 | for i in range(len(timesteps)) 114 | ) 115 | -------------------------------------------------------------------------------- /src/animatediff/rife/ncnn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class RifeNCNNOptions(BaseModel): 11 | model_path: Path = Field(..., description="Path to RIFE model directory") 12 | input_path: Path = Field(..., description="Path to source frames directory") 13 | output_path: Optional[Path] = Field(None, description="Path to output frames directory") 14 | num_frame: Optional[int] = Field(None, description="Number of frames to generate (default N*2)") 15 | time_step: float = Field(0.5, description="Time step for interpolation (default 0.5)", gt=0.0, le=1.0) 16 | gpu_id: Optional[int | list[int]] = Field( 17 | None, description="GPU ID(s) to use (default: auto, -1 for CPU)" 18 | ) 19 | load_threads: int = Field(1, description="Number of threads for frame loading", gt=0) 20 | process_threads: int = Field(2, description="Number of threads used for frame processing", gt=0) 21 | save_threads: int = Field(2, description="Number of threads for frame saving", gt=0) 22 | spatial_tta: bool = Field(False, description="Enable spatial TTA mode") 23 | temporal_tta: bool = Field(False, description="Enable temporal TTA mode") 24 | uhd: bool = Field(False, description="Enable UHD mode") 25 | verbose: bool = Field(False, description="Enable verbose logging") 26 | 27 | def get_args(self, frame_multiplier: int = 7) -> list[str]: 28 | """Generate arguments to pass to rife-ncnn-vulkan. 29 | 30 | Frame multiplier is used to calculate the number of frames to generate, if num_frame is not set. 
31 | """ 32 | if self.output_path is None: 33 | self.output_path = self.input_path.joinpath("out") 34 | 35 | # calc num frames 36 | if self.num_frame is None: 37 | num_src_frames = len([x for x in self.input_path.glob("*.png") if x.is_file()]) 38 | logger.info(f"Found {num_src_frames} source frames, using multiplier {frame_multiplier}") 39 | num_frame = num_src_frames * frame_multiplier 40 | logger.info(f"We will generate {num_frame} frames") 41 | else: 42 | num_frame = self.num_frame 43 | 44 | # GPU ID and process threads are comma-separated lists, so we need to convert them to strings 45 | if self.gpu_id is None: 46 | gpu_id = "auto" 47 | process_threads = self.process_threads 48 | elif isinstance(self.gpu_id, list): 49 | gpu_id = ",".join([str(x) for x in self.gpu_id]) 50 | process_threads = ",".join([str(self.process_threads) for _ in self.gpu_id]) 51 | else: 52 | gpu_id = str(self.gpu_id) 53 | process_threads = str(self.process_threads) 54 | 55 | # Build args list 56 | args_list = [ 57 | "-i", 58 | f"{self.input_path.resolve()}/", 59 | "-o", 60 | f"{self.output_path.resolve()}/", 61 | "-m", 62 | f"{self.model_path.resolve()}/", 63 | "-n", 64 | num_frame, 65 | "-s", 66 | f"{self.time_step:.5f}", 67 | "-g", 68 | gpu_id, 69 | "-j", 70 | f"{self.load_threads}:{process_threads}:{self.save_threads}", 71 | ] 72 | 73 | # Add flags if set 74 | if self.spatial_tta: 75 | args_list.append("-x") 76 | if self.temporal_tta: 77 | args_list.append("-z") 78 | if self.uhd: 79 | args_list.append("-u") 80 | if self.verbose: 81 | args_list.append("-v") 82 | 83 | # Convert all args to strings and return 84 | return [str(x) for x in args_list] 85 | -------------------------------------------------------------------------------- /src/animatediff/schedulers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import Enum 3 | 4 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 5 | DPMSolverSinglestepScheduler, 6 | EulerAncestralDiscreteScheduler, 7 | EulerDiscreteScheduler, 8 | HeunDiscreteScheduler, 9 | KDPM2AncestralDiscreteScheduler, 10 | KDPM2DiscreteScheduler, LCMScheduler, 11 | LMSDiscreteScheduler, PNDMScheduler, 12 | UniPCMultistepScheduler) 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111 18 | class DiffusionScheduler(str, Enum): 19 | lcm = "lcm" # LCM 20 | ddim = "ddim" # DDIM 21 | pndm = "pndm" # PNDM 22 | heun = "heun" # Heun 23 | unipc = "unipc" # UniPC 24 | euler = "euler" # Euler 25 | euler_a = "euler_a" # Euler a 26 | 27 | lms = "lms" # LMS 28 | k_lms = "k_lms" # LMS Karras 29 | 30 | dpm_2 = "dpm_2" # DPM2 31 | k_dpm_2 = "k_dpm_2" # DPM2 Karras 32 | 33 | dpm_2_a = "dpm_2_a" # DPM2 a 34 | k_dpm_2_a = "k_dpm_2_a" # DPM2 a Karras 35 | 36 | dpmpp_2m = "dpmpp_2m" # DPM++ 2M 37 | k_dpmpp_2m = "k_dpmpp_2m" # DPM++ 2M Karras 38 | 39 | dpmpp_sde = "dpmpp_sde" # DPM++ SDE 40 | k_dpmpp_sde = "k_dpmpp_sde" # DPM++ SDE Karras 41 | 42 | dpmpp_2m_sde = "dpmpp_2m_sde" # DPM++ 2M SDE 43 | k_dpmpp_2m_sde = "k_dpmpp_2m_sde" # DPM++ 2M SDE Karras 44 | 45 | 46 | def get_scheduler(name: str, config: dict = {}): 47 | is_karras = name.startswith("k_") 48 | if is_karras: 49 | # strip the k_ prefix and add the karras sigma flag to config 50 | name = name.lstrip("k_") 51 | config["use_karras_sigmas"] = True 52 | 53 | match name: 54 | case DiffusionScheduler.lcm: 55 | sched_class = LCMScheduler 56 | case 
DiffusionScheduler.ddim:
57 | sched_class = DDIMScheduler
58 | case DiffusionScheduler.pndm:
59 | sched_class = PNDMScheduler
60 | case DiffusionScheduler.heun:
61 | sched_class = HeunDiscreteScheduler
62 | case DiffusionScheduler.unipc:
63 | sched_class = UniPCMultistepScheduler
64 | case DiffusionScheduler.euler:
65 | sched_class = EulerDiscreteScheduler
66 | case DiffusionScheduler.euler_a:
67 | sched_class = EulerAncestralDiscreteScheduler
68 | case DiffusionScheduler.lms:
69 | sched_class = LMSDiscreteScheduler
70 | case DiffusionScheduler.dpm_2:
71 | # Equivalent to DPM2 in K-Diffusion
72 | sched_class = KDPM2DiscreteScheduler
73 | case DiffusionScheduler.dpm_2_a:
74 | # Equivalent to `DPM2 a` in K-Diffusion
75 | sched_class = KDPM2AncestralDiscreteScheduler
76 | case DiffusionScheduler.dpmpp_2m:
77 | # Equivalent to `DPM++ 2M` in K-Diffusion
78 | sched_class = DPMSolverMultistepScheduler
79 | config["algorithm_type"] = "dpmsolver++"
80 | config["solver_order"] = 2
81 | case DiffusionScheduler.dpmpp_sde:
82 | # Equivalent to `DPM++ SDE` in K-Diffusion
83 | sched_class = DPMSolverSinglestepScheduler
84 | case DiffusionScheduler.dpmpp_2m_sde:
85 | # Equivalent to `DPM++ 2M SDE` in K-Diffusion
86 | sched_class = DPMSolverMultistepScheduler
87 | config["algorithm_type"] = "sde-dpmsolver++"
88 | case _:
89 | raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")
90 | 
91 | return sched_class.from_config(config)
92 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Automatic TikTok generator
2 | 
3 | This repository contains a pipeline for video-to-video generation using text prompts. The system leverages AnimateDiff and OpenPose ControlNet for pose estimation, and incorporates a prompt traveling method for improved coherence between the original and generated videos. Users can interact with this pipeline through a Gradio app or a standard Python program.
4 | 
5 | ## Techniques used
6 | 
7 | - **AnimateDiff**: Utilized for generating high-quality animations from text prompts and an input image.
8 | - **OpenPose ControlNet**: Used for accurate pose estimation to guide the animation process.
9 | - **Prompt Traveling Method**: Improves the correspondence and temporal coherence between the input video and the generated output.
10 | - **User Interfaces**:
11 | - **Gradio App**: An intuitive web-based interface for easy interaction.
12 | - **Python Program**: A script-based interface for users preferring command-line interaction.
13 | 
14 | ### Base models
15 | 
16 | - [XXMix_9realistic](https://civitai.com/models/47274): Model used for generating life-like video (Recommended for life-like video)
17 | - [Mistoon_Anime](https://civitai.com/models/24149/mistoonanime): Model used for generating anime-like video (Recommended for anime-like video)
18 | 
19 | ### Motion modules
20 | 
21 | - [mm_sd_v15_v2](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt): Motion module used for generating segments of the final video from the generated images (Recommended)
22 | - [mm_sd_v15](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15.ckpt) and [mm_sd_v14](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v14.ckpt) are other modules that can also be used; the example configuration below shows where these model paths are referenced.
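Roughly speaking, the prompt configurations under `config/prompts/` are what tie a base checkpoint, a fine-tuned model and a motion module together. The sketch below only illustrates where those paths go — the file names are placeholders for whatever models you download, and a real config contains more fields (prompts, seeds, and so on) — using the same field names as the existing configs in this repository:

```json
{
  "name": "example",
  "base": "models/sd/majicmixRealistic_v4.safetensors",
  "path": "models/huggingface/xxmix9realistic_v40.safetensors",
  "motion_module": "models/motion-module/mm_sd_v15_v2.ckpt",
  "scheduler": "k_dpmpp",
  "steps": 25,
  "guidance_scale": 7.5
}
```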
23 | 
24 | ### ControlNets
25 | 
26 | - [control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_openpose.pth): ControlNet for pose estimation from the given video
27 | - Support for depth and canny ControlNets is planned as well, for better generated video quality.
28 | 
29 | ### Prompt Travelling
30 | 
31 | Prompt travelling is a technique used to tell the model, frame by frame, what the output image should look like.
32 | For example, if the prompt body contains an entry like `30 - face: up, camera: zoomed out, right-hand: waving`, then the 30th output frame is generated according to that prompt (see the example prompt map at the end of the Usage section below).
33 | 
34 | ## Installation
35 | 
36 | To set up the environment and install the necessary dependencies, follow these steps:
37 | 
38 | 1. **Clone the repository:**
39 | 
40 | ```bash
41 | git clone https://github.com/TheNetherWatcher/Vid2Vid-using-Text-prompt.git
42 | cd Vid2Vid-using-Text-prompt
43 | ```
44 | 
45 | 2. **Create and activate a virtual environment:**
46 | 
47 | ```bash
48 | python -m venv venv
49 | source venv/bin/activate # On Windows, use `venv\Scripts\activate`
50 | ```
51 | 
52 | 3. **Install the required packages:**
53 | 
54 | ```bash
55 | pip install -e .
56 | pip install -e .[stylize]
57 | ```
58 | 
59 | ## Usage
60 | 
61 | ### Model weights
62 | 
63 | - Download the model weights from the above links (or others of your choice) and put them [here](./data/models/huggingface); put the downloaded motion modules [here](data/models/motion-module)
64 | - On the first run, you might get errors like "model weights not found"; in that case, go to the stylize directory and, in the most recently created folder, edit the model name in the prompt.json file. Better support for this is under development.
65 | 
66 | ### Gradio App
67 | 
68 | To run the Gradio app, execute the following command:
69 | 
70 | ```bash
71 | python app.py
72 | ```
73 | 
74 | The Gradio app provides an interface for uploading a video and providing a text prompt as input, and it outputs the generated video.
75 | 
76 | ### Commandline
77 | 
78 | ```bash
79 | python test.py
80 | ```
81 | 
82 | After running this, you will be prompted to enter the location of the video, a positive prompt (the changes that you want to make to the video), and a negative prompt.
83 | The negative prompt is set to a default value, but you can edit it if you like.
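In this codebase, per-frame prompts are expressed as a frame-indexed `prompt_map` alongside a `head_prompt` (see, for instance, how `src/animatediff/utils/civitai2config.py` builds a `prompt_map`, and how `app.py` sets `head_prompt`). The snippet below is only a sketch of the shape of such a map — the prompt texts are placeholders, and the actual config written by `animatediff stylize create-config` contains additional fields such as the model paths shown earlier — but it shows how per-frame instructions are keyed by frame number:

```json
{
  "head_prompt": "masterpiece, best quality",
  "prompt_map": {
    "0": "walking down a street, daylight",
    "30": "face up, camera zoomed out, right hand waving",
    "60": "night, neon lights, bokeh"
  }
}
```

Roughly speaking, frames between two keys are influenced by the surrounding keyframe prompts, which is what produces the gradual transitions in the generated video.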
84 | 
85 | ## Upcoming Developments
86 | 
87 | - LoRA support, and ControlNet support (e.g. canny, depth, edge)
88 | - Gradio app support for using different ControlNets and LoRAs
89 | - CLI options for controlling the execution on different systems
90 | 
91 | ## Credits
92 | 
93 | - [AnimateDiff](https://github.com/guoyww/AnimateDiff)
94 | - [Prompt Travelling using AnimateDiff](https://github.com/s9roll7/animatediff-cli-prompt-travel)
95 | 
-------------------------------------------------------------------------------- /src/animatediff/utils/civitai2config.py: --------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import shutil
7 | from pathlib import Path
8 | 
9 | from animatediff import get_dir
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | data_dir = get_dir("data")
14 | 
15 | extra_loading_regex = r'(<[^>]+?>)'
16 | 
17 | def generate_config_from_civitai_info(
18 | lora_dir:Path,
19 | config_org:Path,
20 | out_dir:Path,
21 | lora_weight:float,
22 | ):
23 | lora_abs_dir = lora_dir.absolute()
24 | config_org = config_org.absolute()
25 | out_dir = out_dir.absolute()
26 | 
27 | civitais = sorted(glob.glob( os.path.join(lora_abs_dir, "*.civitai.info"), recursive=False))
28 | 
29 | with open(config_org, "r") as cf:
30 | org_config = json.load(cf)
31 | 
32 | for civ in civitais:
33 | 
34 | logger.info(f"convert {civ}")
35 | 
36 | with open(civ, "r") as f:
37 | # trim .civitai.info
38 | name = os.path.splitext(os.path.splitext(os.path.basename(civ))[0])[0]
39 | 
40 | output_path = out_dir.joinpath(name + ".json")
41 | 
42 | if os.path.isfile(output_path):
43 | logger.info("already converted -> skip")
44 | continue
45 | 
46 | if os.path.isfile( lora_abs_dir.joinpath(name + ".safetensors")):
47 | lora_path = os.path.relpath(lora_abs_dir.joinpath(name + ".safetensors"), data_dir)
48 | elif os.path.isfile( lora_abs_dir.joinpath(name + ".ckpt")):
49 | lora_path = os.path.relpath(lora_abs_dir.joinpath(name + ".ckpt"), data_dir)
50 | else:
51 | logger.info("lora file not found -> skip")
52 | continue
53 | 
54 | info = json.load(f)
55 | 
56 | if not info:
57 | logger.info(f"empty civitai info -> skip")
58 | continue
59 | 
60 | if info["model"]["type"] not in ("LORA","lora"):
61 | logger.info(f"unsupported type {info['model']['type']} -> skip")
62 | continue
63 | 
64 | new_config = org_config.copy()
65 | 
66 | new_config["name"] = name
67 | 
68 | new_prompt_map = {}
69 | new_n_prompt = ""
70 | new_seed = -1
71 | 
72 | 
73 | raw_prompt_map = {}
74 | 
75 | i = 0
76 | for img_info in info["images"]:
77 | if img_info["meta"]:
78 | try:
79 | raw_prompt = img_info["meta"]["prompt"]
80 | except Exception as e:
81 | logger.info("missing prompt")
82 | continue
83 | 
84 | raw_prompt_map[str(10000 + i*32)] = raw_prompt
85 | 
86 | new_prompt_map[str(i*32)] = re.sub(extra_loading_regex, '', raw_prompt)
87 | 
88 | if not new_n_prompt:
89 | try:
90 | new_n_prompt = img_info["meta"]["negativePrompt"]
91 | except Exception as e:
92 | new_n_prompt = ""
93 | if new_seed == -1:
94 | try:
95 | new_seed = img_info["meta"]["seed"]
96 | except Exception as e:
97 | new_seed = -1
98 | 
99 | i += 1
100 | 
101 | if not new_prompt_map:
102 | new_prompt_map[str(0)] = ""
103 | 
104 | for k in raw_prompt_map:
105 | # comment
106 | new_prompt_map[k] = raw_prompt_map[k]
107 | 
108 | new_config["prompt_map"] = new_prompt_map
109 | new_config["n_prompt"] = [new_n_prompt]
110 | new_config["seed"] = [new_seed]
111 | 
112 | new_config["lora_map"] =
{lora_path.replace(os.sep,'/'):lora_weight} 113 | 114 | with open( out_dir.joinpath(name + ".json"), 'w') as wf: 115 | json.dump(new_config, wf, indent=4) 116 | logger.info("converted!") 117 | 118 | preview = lora_abs_dir.joinpath(name + ".preview.png") 119 | if preview.is_file(): 120 | shutil.copy(preview, out_dir.joinpath(name + ".preview.png")) 121 | 122 | 123 | -------------------------------------------------------------------------------- /src/animatediff/utils/device.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import lru_cache 3 | from math import ceil 4 | from typing import Union 5 | 6 | import torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def device_info_str(device: torch.device) -> str: 12 | device_info = torch.cuda.get_device_properties(device) 13 | return ( 14 | f"{device_info.name} {ceil(device_info.total_memory / 1024 ** 3)}GB, " 15 | + f"CC {device_info.major}.{device_info.minor}, {device_info.multi_processor_count} SM(s)" 16 | ) 17 | 18 | 19 | @lru_cache(maxsize=4) 20 | def supports_bfloat16(device: Union[str, torch.device]) -> bool: 21 | """A non-exhaustive check for bfloat16 support on a given device. 22 | Weird that torch doesn't have a global function for this. If your device 23 | does support bfloat16 and it's not listed here, go ahead and add it. 24 | """ 25 | device = torch.device(device) # make sure device is a torch.device 26 | match device.type: 27 | case "cpu": 28 | ret = False 29 | case "cuda": 30 | with device: 31 | ret = torch.cuda.is_bf16_supported() 32 | case "xla": 33 | ret = True 34 | case "mps": 35 | ret = True 36 | case _: 37 | ret = False 38 | return ret 39 | 40 | 41 | @lru_cache(maxsize=4) 42 | def maybe_bfloat16( 43 | device: Union[str, torch.device], 44 | fallback: torch.dtype = torch.float32, 45 | ) -> torch.dtype: 46 | """Returns torch.bfloat16 if available, otherwise the fallback dtype (default float32)""" 47 | device = torch.device(device) # make sure device is a torch.device 48 | return torch.bfloat16 if supports_bfloat16(device) else fallback 49 | 50 | 51 | def dtype_for_model(model: str, device: torch.device) -> torch.dtype: 52 | match model: 53 | case "unet": 54 | return torch.float32 if device.type == "cpu" else torch.float16 55 | case "tenc": 56 | return torch.float32 if device.type == "cpu" else torch.float16 57 | case "vae": 58 | return maybe_bfloat16(device, fallback=torch.float32) 59 | case unknown: 60 | raise ValueError(f"Invalid model {unknown}") 61 | 62 | 63 | def get_model_dtypes( 64 | device: Union[str, torch.device], 65 | force_half_vae: bool = False, 66 | ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: 67 | device = torch.device(device) # make sure device is a torch.device 68 | unet_dtype = dtype_for_model("unet", device) 69 | tenc_dtype = dtype_for_model("tenc", device) 70 | vae_dtype = dtype_for_model("vae", device) 71 | 72 | if device.type == "cpu": 73 | logger.warn("Device explicitly set to CPU, will run everything in fp32") 74 | logger.warn("This is likely to be *incredibly* slow, but I don't tell you how to live.") 75 | 76 | if force_half_vae: 77 | if device.type == "cpu": 78 | logger.critical("Can't force VAE to fp16 mode on CPU! Exiting...") 79 | raise RuntimeError("Can't force VAE to fp16 mode on CPU!") 80 | if vae_dtype == torch.bfloat16: 81 | logger.warn("Forcing VAE to use fp16 despite bfloat16 support! 
This is a bad idea!") 82 | logger.warn("If you're not sure why you're doing this, you probably shouldn't be.") 83 | vae_dtype = torch.float16 84 | else: 85 | logger.warn("Forcing VAE to use fp16 instead of fp32 on CUDA! This may result in black outputs!") 86 | logger.warn("Running a VAE in fp16 can result in black images or poor output quality.") 87 | logger.warn("I don't tell you how to live, but you probably shouldn't do this.") 88 | vae_dtype = torch.float16 89 | 90 | logger.info(f"Selected data types: {unet_dtype=}, {tenc_dtype=}, {vae_dtype=}") 91 | return unet_dtype, tenc_dtype, vae_dtype 92 | 93 | 94 | def get_memory_format(device: Union[str, torch.device]) -> torch.memory_format: 95 | device = torch.device(device) # make sure device is a torch.device 96 | # if we have a cuda device 97 | if device.type == "cuda": 98 | device_info = torch.cuda.get_device_properties(device) 99 | # Volta and newer seem to like channels_last. This will probably bite me on TU11x cards. 100 | if device_info.major >= 7: 101 | ret = torch.channels_last 102 | else: 103 | ret = torch.contiguous_format 104 | elif device.type == "xpu": 105 | # Intel ARC GPUs/XPUs like channels_last 106 | ret = torch.channels_last 107 | else: 108 | # TODO: Does MPS like channels_last? do other devices? 109 | ret = torch.contiguous_format 110 | if ret == torch.channels_last: 111 | logger.info("Using channels_last memory format for UNet and VAE") 112 | return ret -------------------------------------------------------------------------------- /src/animatediff/utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | 4 | import torch 5 | import torch._dynamo as dynamo 6 | from diffusers import (DiffusionPipeline, StableDiffusionPipeline, 7 | StableDiffusionXLPipeline) 8 | from einops._torch_specific import allow_ops_in_compiled_graph 9 | 10 | from animatediff.utils.device import get_memory_format, get_model_dtypes 11 | from animatediff.utils.model import nop_train 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def send_to_device( 17 | pipeline: DiffusionPipeline, 18 | device: torch.device, 19 | freeze: bool = True, 20 | force_half: bool = False, 21 | compile: bool = False, 22 | is_sdxl: bool = False, 23 | ) -> DiffusionPipeline: 24 | if is_sdxl: 25 | return send_to_device_sdxl( 26 | pipeline=pipeline, 27 | device=device, 28 | freeze=freeze, 29 | force_half=force_half, 30 | compile=compile, 31 | ) 32 | 33 | logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"") 34 | 35 | unet_dtype, tenc_dtype, vae_dtype = get_model_dtypes(device, force_half) 36 | model_memory_format = get_memory_format(device) 37 | 38 | if hasattr(pipeline, 'controlnet'): 39 | unet_dtype = tenc_dtype = vae_dtype 40 | 41 | logger.info(f"-> Selected data types: {unet_dtype=},{tenc_dtype=},{vae_dtype=}") 42 | 43 | if hasattr(pipeline.controlnet, 'nets'): 44 | for i in range(len(pipeline.controlnet.nets)): 45 | pipeline.controlnet.nets[i] = pipeline.controlnet.nets[i].to(device=device, dtype=vae_dtype, memory_format=model_memory_format) 46 | else: 47 | if pipeline.controlnet: 48 | pipeline.controlnet = pipeline.controlnet.to(device=device, dtype=vae_dtype, memory_format=model_memory_format) 49 | 50 | if hasattr(pipeline, 'controlnet_map'): 51 | if pipeline.controlnet_map: 52 | for c in pipeline.controlnet_map: 53 | #pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(device=device, dtype=unet_dtype, 
memory_format=model_memory_format) 54 | pipeline.controlnet_map[c] = pipeline.controlnet_map[c].to(dtype=unet_dtype, memory_format=model_memory_format) 55 | 56 | if hasattr(pipeline, 'lora_map'): 57 | if pipeline.lora_map: 58 | pipeline.lora_map.to(device=device, dtype=unet_dtype) 59 | 60 | if hasattr(pipeline, 'lcm'): 61 | if pipeline.lcm: 62 | pipeline.lcm.to(device=device, dtype=unet_dtype) 63 | 64 | pipeline.unet = pipeline.unet.to(device=device, dtype=unet_dtype, memory_format=model_memory_format) 65 | pipeline.text_encoder = pipeline.text_encoder.to(device=device, dtype=tenc_dtype) 66 | pipeline.vae = pipeline.vae.to(device=device, dtype=vae_dtype, memory_format=model_memory_format) 67 | 68 | # Compile model if enabled 69 | if compile: 70 | if not isinstance(pipeline.unet, dynamo.OptimizedModule): 71 | allow_ops_in_compiled_graph() # make einops behave 72 | logger.warn("Enabling model compilation with TorchDynamo, this may take a while...") 73 | logger.warn("Model compilation is experimental and may not work as expected!") 74 | pipeline.unet = torch.compile( 75 | pipeline.unet, 76 | backend="inductor", 77 | mode="reduce-overhead", 78 | ) 79 | else: 80 | logger.debug("Skipping model compilation, already compiled!") 81 | 82 | return pipeline 83 | 84 | 85 | def send_to_device_sdxl( 86 | pipeline: StableDiffusionXLPipeline, 87 | device: torch.device, 88 | freeze: bool = True, 89 | force_half: bool = False, 90 | compile: bool = False, 91 | ) -> StableDiffusionXLPipeline: 92 | logger.info(f"Sending pipeline to device \"{device.type}{device.index if device.index else ''}\"") 93 | 94 | pipeline.unet = pipeline.unet.half() 95 | pipeline.text_encoder = pipeline.text_encoder.half() 96 | pipeline.text_encoder_2 = pipeline.text_encoder_2.half() 97 | 98 | if False: 99 | pipeline.to(device) 100 | else: 101 | pipeline.enable_model_cpu_offload() 102 | 103 | pipeline.enable_xformers_memory_efficient_attention() 104 | pipeline.enable_vae_slicing() 105 | pipeline.enable_vae_tiling() 106 | 107 | return pipeline 108 | 109 | 110 | 111 | def get_context_params( 112 | length: int, 113 | context: Optional[int] = None, 114 | overlap: Optional[int] = None, 115 | stride: Optional[int] = None, 116 | ): 117 | if context is None: 118 | context = min(length, 16) 119 | if overlap is None: 120 | overlap = context // 4 121 | if stride is None: 122 | stride = 0 123 | return context, overlap, stride 124 | -------------------------------------------------------------------------------- /src/animatediff/dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | import cv2 3 | import numpy as np 4 | import onnxruntime 5 | 6 | 7 | def nms(boxes, scores, nms_thr): 8 | """Single class NMS implemented in Numpy.""" 9 | x1 = boxes[:, 0] 10 | y1 = boxes[:, 1] 11 | x2 = boxes[:, 2] 12 | y2 = boxes[:, 3] 13 | 14 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 15 | order = scores.argsort()[::-1] 16 | 17 | keep = [] 18 | while order.size > 0: 19 | i = order[0] 20 | keep.append(i) 21 | xx1 = np.maximum(x1[i], x1[order[1:]]) 22 | yy1 = np.maximum(y1[i], y1[order[1:]]) 23 | xx2 = np.minimum(x2[i], x2[order[1:]]) 24 | yy2 = np.minimum(y2[i], y2[order[1:]]) 25 | 26 | w = np.maximum(0.0, xx2 - xx1 + 1) 27 | h = np.maximum(0.0, yy2 - yy1 + 1) 28 | inter = w * h 29 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 30 | 31 | inds = np.where(ovr <= nms_thr)[0] 32 | order = order[inds + 1] 33 | 34 | return keep 35 | 36 | def multiclass_nms(boxes, scores, 
nms_thr, score_thr): 37 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 38 | final_dets = [] 39 | num_classes = scores.shape[1] 40 | for cls_ind in range(num_classes): 41 | cls_scores = scores[:, cls_ind] 42 | valid_score_mask = cls_scores > score_thr 43 | if valid_score_mask.sum() == 0: 44 | continue 45 | else: 46 | valid_scores = cls_scores[valid_score_mask] 47 | valid_boxes = boxes[valid_score_mask] 48 | keep = nms(valid_boxes, valid_scores, nms_thr) 49 | if len(keep) > 0: 50 | cls_inds = np.ones((len(keep), 1)) * cls_ind 51 | dets = np.concatenate( 52 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 53 | ) 54 | final_dets.append(dets) 55 | if len(final_dets) == 0: 56 | return None 57 | return np.concatenate(final_dets, 0) 58 | 59 | def demo_postprocess(outputs, img_size, p6=False): 60 | grids = [] 61 | expanded_strides = [] 62 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 63 | 64 | hsizes = [img_size[0] // stride for stride in strides] 65 | wsizes = [img_size[1] // stride for stride in strides] 66 | 67 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 68 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 69 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 70 | grids.append(grid) 71 | shape = grid.shape[:2] 72 | expanded_strides.append(np.full((*shape, 1), stride)) 73 | 74 | grids = np.concatenate(grids, 1) 75 | expanded_strides = np.concatenate(expanded_strides, 1) 76 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 77 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 78 | 79 | return outputs 80 | 81 | def preprocess(img, input_size, swap=(2, 0, 1)): 82 | if len(img.shape) == 3: 83 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 84 | else: 85 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 86 | 87 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 88 | resized_img = cv2.resize( 89 | img, 90 | (int(img.shape[1] * r), int(img.shape[0] * r)), 91 | interpolation=cv2.INTER_LINEAR, 92 | ).astype(np.uint8) 93 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 94 | 95 | padded_img = padded_img.transpose(swap) 96 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 97 | return padded_img, r 98 | 99 | def inference_detector(session, oriImg): 100 | input_shape = (640,640) 101 | img, ratio = preprocess(oriImg, input_shape) 102 | 103 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 104 | output = session.run(None, ort_inputs) 105 | predictions = demo_postprocess(output[0], input_shape)[0] 106 | 107 | boxes = predictions[:, :4] 108 | scores = predictions[:, 4:5] * predictions[:, 5:] 109 | 110 | boxes_xyxy = np.ones_like(boxes) 111 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 112 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 113 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 114 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 
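# scale the corner-format boxes back to the original image size, then run class-aware NMS and keep confident person detections (class 0)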
115 | boxes_xyxy /= ratio 116 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 117 | if dets is not None: 118 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 119 | isscore = final_scores>0.3 120 | iscat = final_cls_inds == 0 121 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 122 | final_boxes = final_boxes[isbbox] 123 | else: 124 | return [] 125 | 126 | return final_boxes 127 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### Python ### 49 | # Byte-compiled / optimized / DLL files 50 | __pycache__/ 51 | *.py[cod] 52 | *$py.class 53 | 54 | # C extensions 55 | *.so 56 | 57 | # Distribution / packaging 58 | .Python 59 | build/ 60 | develop-eggs/ 61 | dist/ 62 | downloads/ 63 | eggs/ 64 | .eggs/ 65 | lib/ 66 | lib64/ 67 | parts/ 68 | sdist/ 69 | var/ 70 | wheels/ 71 | share/python-wheels/ 72 | *.egg-info/ 73 | .installed.cfg 74 | *.egg 75 | MANIFEST 76 | 77 | # PyInstaller 78 | # Usually these files are written by a python script from a template 79 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
80 | *.manifest 81 | *.spec 82 | 83 | # Installer logs 84 | pip-log.txt 85 | pip-delete-this-directory.txt 86 | 87 | # Unit test / coverage reports 88 | htmlcov/ 89 | .tox/ 90 | .nox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *.cover 97 | *.py,cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | cover/ 101 | 102 | # Translations 103 | *.mo 104 | *.pot 105 | 106 | # Django stuff: 107 | *.log 108 | local_settings.py 109 | db.sqlite3 110 | db.sqlite3-journal 111 | 112 | # Flask stuff: 113 | instance/ 114 | .webassets-cache 115 | 116 | # Scrapy stuff: 117 | .scrapy 118 | 119 | # Sphinx documentation 120 | docs/_build/ 121 | 122 | # PyBuilder 123 | .pybuilder/ 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # IPython 130 | profile_default/ 131 | ipython_config.py 132 | 133 | # pyenv 134 | # For a library or package, you might want to ignore these files since the code is 135 | # intended to run in multiple environments; otherwise, check them in: 136 | # .python-version 137 | 138 | # pipenv 139 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 140 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 141 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 142 | # install all needed dependencies. 143 | #Pipfile.lock 144 | 145 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 146 | __pypackages__/ 147 | 148 | # Celery stuff 149 | celerybeat-schedule 150 | celerybeat.pid 151 | 152 | # SageMath parsed files 153 | *.sage.py 154 | 155 | # Environments 156 | .env 157 | .venv 158 | env/ 159 | venv/ 160 | ENV/ 161 | env.bak/ 162 | venv.bak/ 163 | 164 | # Spyder project settings 165 | .spyderproject 166 | .spyproject 167 | 168 | # Rope project settings 169 | .ropeproject 170 | 171 | # mkdocs documentation 172 | /site 173 | 174 | # mypy 175 | .mypy_cache/ 176 | .dmypy.json 177 | dmypy.json 178 | 179 | # Pyre type checker 180 | .pyre/ 181 | 182 | # pytype static type analyzer 183 | .pytype/ 184 | 185 | # Cython debug symbols 186 | cython_debug/ 187 | 188 | ### VisualStudioCode ### 189 | .vscode/* 190 | !.vscode/settings.json 191 | !.vscode/tasks.json 192 | !.vscode/launch.json 193 | !.vscode/extensions.json 194 | *.code-workspace 195 | 196 | # Local History for Visual Studio Code 197 | .history/ 198 | 199 | ### VisualStudioCode Patch ### 200 | # Ignore all local history of files 201 | .history 202 | .ionide 203 | 204 | ### Windows ### 205 | # Windows thumbnail cache files 206 | Thumbs.db 207 | Thumbs.db:encryptable 208 | ehthumbs.db 209 | ehthumbs_vista.db 210 | 211 | # Dump file 212 | *.stackdump 213 | 214 | # Folder config file 215 | [Dd]esktop.ini 216 | 217 | # Recycle Bin used on file shares 218 | $RECYCLE.BIN/ 219 | 220 | # Windows Installer files 221 | *.cab 222 | *.msi 223 | *.msix 224 | *.msm 225 | *.msp 226 | 227 | # Windows shortcuts 228 | *.lnk 229 | 230 | # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python 231 | 232 | # setuptools-scm _version file 233 | src/animatediff/_version.py 234 | 235 | # local misc and temp 236 | /misc/ 237 | /temp/ 238 | 239 | # envrc 240 | .env* 241 | !.envrc.example 242 | -------------------------------------------------------------------------------- /src/animatediff/softmax_splatting/README.md: -------------------------------------------------------------------------------- 1 | # softmax-splatting 2 | This is 
a reference implementation of the softmax splatting operator, which has been proposed in Softmax Splatting for Video Frame Interpolation [1], using PyTorch. Softmax splatting is a well-motivated approach for differentiable forward warping. It uses a translation-invariant importance metric to disambiguate cases where multiple source pixels map to the same target pixel. Should you be making use of our work, please cite our paper [1].
3 | 
4 | Paper
5 | 
6 | For our previous work on SepConv, see: https://github.com/sniklaus/revisiting-sepconv
7 | 
8 | ## setup
9 | The softmax splatting is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository.
10 | 
11 | If you plan to process videos, then please also make sure to have `pip install moviepy` installed.
12 | 
13 | ## usage
14 | To run it on your own pair of frames, use the following command.
15 | 
16 | ```
17 | python run.py --model lf --one ./images/one.png --two ./images/two.png --out ./out.png
18 | ```
19 | 
20 | To run it on a video, use the following command.
21 | 
22 | ```
23 | python run.py --model lf --video ./videos/car-turn.mp4 --out ./out.mp4
24 | ```
25 | 
26 | For a quick benchmark using examples from the Middlebury benchmark for optical flow, run `python benchmark_middlebury.py`. You can use it to easily verify that the provided implementation runs as expected.
27 | 
28 | ## warping
29 | We provide a small script to replicate the third figure of our paper [1]. You can simply run the following to obtain the comparison between summation splatting, average splatting, linear splatting, and softmax splatting.
30 | 
31 | The example script uses OpenCV to load and display images, as well as to read the provided optical flow file. An easy way to install OpenCV for Python is using the `pip install opencv-contrib-python` package.
32 | 33 | ``` 34 | import cv2 35 | import numpy 36 | import torch 37 | 38 | import run 39 | 40 | import softsplat # the custom softmax splatting layer 41 | 42 | ########################################################## 43 | 44 | torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance 45 | 46 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 47 | 48 | ########################################################## 49 | 50 | tenOne = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() 51 | tenTwo = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() 52 | tenFlow = torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda() 53 | 54 | tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=run.backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True) 55 | 56 | for intTime, fltTime in enumerate(numpy.linspace(0.0, 1.0, 11).tolist()): 57 | tenSummation = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='sum') 58 | tenAverage = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='avg') 59 | tenLinear = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(0.3 - tenMetric).clip(0.001, 1.0), strMode='linear') # finding a good linearly metric is difficult, and it is not invariant to translations 60 | tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft') # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter 61 | 62 | cv2.imshow(winname='summation', mat=tenSummation[0, :, :, :].cpu().numpy().transpose(1, 2, 0)) 63 | cv2.imshow(winname='average', mat=tenAverage[0, :, :, :].cpu().numpy().transpose(1, 2, 0)) 64 | cv2.imshow(winname='linear', mat=tenLinear[0, :, :, :].cpu().numpy().transpose(1, 2, 0)) 65 | cv2.imshow(winname='softmax', mat=tenSoftmax[0, :, :, :].cpu().numpy().transpose(1, 2, 0)) 66 | cv2.waitKey(delay=0) 67 | # end 68 | ``` 69 | 70 | ## xiph 71 | In our paper, we propose to use 4K video clips from Xiph to evaluate video frame interpolation on high-resolution footage. Please see the supplementary `benchmark_xiph.py` on how to reproduce the shown metrics. 72 | 73 | ## video 74 | Video 75 | 76 | ## license 77 | The provided implementation is strictly for academic purposes only. Should you be interested in using our technology for any commercial use, please feel free to contact us. 78 | 79 | ## references 80 | ``` 81 | [1] @inproceedings{Niklaus_CVPR_2020, 82 | author = {Simon Niklaus and Feng Liu}, 83 | title = {Softmax Splatting for Video Frame Interpolation}, 84 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, 85 | year = {2020} 86 | } 87 | ``` 88 | 89 | ## acknowledgment 90 | The video above uses materials under a Creative Common license as detailed at the end. 
-------------------------------------------------------------------------------- /src/animatediff/utils/huggingface.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from os import PathLike 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline 7 | from huggingface_hub import hf_hub_download, snapshot_download 8 | from tqdm.rich import tqdm 9 | 10 | from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir 11 | from animatediff.utils.util import path_from_cwd 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | data_dir = get_dir("data") 16 | checkpoint_dir = data_dir.joinpath("models/sd") 17 | pipeline_dir = data_dir.joinpath("models/huggingface") 18 | 19 | IGNORE_TF = ["*.git*", "*.h5", "tf_*"] 20 | IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"] 21 | IGNORE_TF_FLAX = IGNORE_TF + IGNORE_FLAX 22 | 23 | 24 | class DownloadTqdm(tqdm): 25 | def __init__(self, *args, **kwargs): 26 | kwargs.update( 27 | { 28 | "ncols": 100, 29 | "dynamic_ncols": False, 30 | "disable": None, 31 | } 32 | ) 33 | super().__init__(*args, **kwargs) 34 | 35 | 36 | def get_hf_file( 37 | repo_id: Path, 38 | filename: str, 39 | target_dir: Path, 40 | subfolder: Optional[PathLike] = None, 41 | revision: Optional[str] = None, 42 | force: bool = False, 43 | ) -> Path: 44 | target_path = target_dir.joinpath(filename) 45 | if target_path.exists() and force is not True: 46 | raise FileExistsError( 47 | f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite" 48 | ) 49 | 50 | target_dir.mkdir(exist_ok=True, parents=True) 51 | save_path = hf_hub_download( 52 | repo_id=str(repo_id), 53 | filename=filename, 54 | revision=revision or "main", 55 | subfolder=subfolder, 56 | local_dir=target_dir, 57 | local_dir_use_symlinks=False, 58 | cache_dir=HF_HUB_CACHE, 59 | resume_download=True, 60 | ) 61 | return Path(save_path) 62 | 63 | 64 | def get_hf_repo( 65 | repo_id: Path, 66 | target_dir: Path, 67 | subfolder: Optional[PathLike] = None, 68 | revision: Optional[str] = None, 69 | force: bool = False, 70 | ) -> Path: 71 | if target_dir.exists() and force is not True: 72 | raise FileExistsError( 73 | f"Target dir {path_from_cwd(target_dir)} already exists! 
Pass force=True to overwrite" 74 | ) 75 | 76 | target_dir.mkdir(exist_ok=True, parents=True) 77 | save_path = snapshot_download( 78 | repo_id=str(repo_id), 79 | revision=revision or "main", 80 | subfolder=subfolder, 81 | library_name=HF_LIB_NAME, 82 | library_version=HF_LIB_VER, 83 | local_dir=target_dir, 84 | local_dir_use_symlinks=False, 85 | ignore_patterns=IGNORE_TF_FLAX, 86 | cache_dir=HF_HUB_CACHE, 87 | tqdm_class=DownloadTqdm, 88 | max_workers=2, 89 | resume_download=True, 90 | ) 91 | return Path(save_path) 92 | 93 | 94 | def get_hf_pipeline( 95 | repo_id: Path, 96 | target_dir: Path, 97 | save: bool = True, 98 | force_download: bool = False, 99 | ) -> StableDiffusionPipeline: 100 | pipeline_exists = target_dir.joinpath("model_index.json").exists() 101 | if pipeline_exists and force_download is not True: 102 | pipeline = StableDiffusionPipeline.from_pretrained( 103 | pretrained_model_name_or_path=target_dir, 104 | local_files_only=True, 105 | ) 106 | else: 107 | target_dir.mkdir(exist_ok=True, parents=True) 108 | pipeline = StableDiffusionPipeline.from_pretrained( 109 | pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"), 110 | cache_dir=HF_HUB_CACHE, 111 | resume_download=True, 112 | ) 113 | if save and force_download: 114 | logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!") 115 | pipeline.save_pretrained(target_dir, safe_serialization=True) 116 | elif save and not pipeline_exists: 117 | logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}") 118 | pipeline.save_pretrained(target_dir, safe_serialization=True) 119 | return pipeline 120 | 121 | def get_hf_pipeline_sdxl( 122 | repo_id: Path, 123 | target_dir: Path, 124 | save: bool = True, 125 | force_download: bool = False, 126 | ) -> StableDiffusionXLPipeline: 127 | import torch 128 | pipeline_exists = target_dir.joinpath("model_index.json").exists() 129 | if pipeline_exists and force_download is not True: 130 | pipeline = StableDiffusionXLPipeline.from_pretrained( 131 | pretrained_model_name_or_path=target_dir, 132 | local_files_only=True, 133 | torch_dtype=torch.float16, use_safetensors=True, variant="fp16" 134 | ) 135 | else: 136 | target_dir.mkdir(exist_ok=True, parents=True) 137 | pipeline = StableDiffusionXLPipeline.from_pretrained( 138 | pretrained_model_name_or_path=str(repo_id).lstrip("./").replace("\\", "/"), 139 | cache_dir=HF_HUB_CACHE, 140 | resume_download=True, 141 | torch_dtype=torch.float16, use_safetensors=True, variant="fp16" 142 | ) 143 | if save and force_download: 144 | logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. 
Overwriting!") 145 | pipeline.save_pretrained(target_dir, safe_serialization=True) 146 | elif save and not pipeline_exists: 147 | logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}") 148 | pipeline.save_pretrained(target_dir, safe_serialization=True) 149 | return pipeline 150 | -------------------------------------------------------------------------------- /src/animatediff/ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from einops import rearrange 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | # FFN 12 | def FeedForward(dim, mult=4): 13 | inner_dim = int(dim * mult) 14 | return nn.Sequential( 15 | nn.LayerNorm(dim), 16 | nn.Linear(dim, inner_dim, bias=False), 17 | nn.GELU(), 18 | nn.Linear(inner_dim, dim, bias=False), 19 | ) 20 | 21 | 22 | def reshape_tensor(x, heads): 23 | bs, length, width = x.shape 24 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 25 | x = x.view(bs, length, heads, -1) 26 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 27 | x = x.transpose(1, 2) 28 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 29 | x = x.reshape(bs, heads, length, -1) 30 | return x 31 | 32 | 33 | class PerceiverAttention(nn.Module): 34 | def __init__(self, *, dim, dim_head=64, heads=8): 35 | super().__init__() 36 | self.scale = dim_head**-0.5 37 | self.dim_head = dim_head 38 | self.heads = heads 39 | inner_dim = dim_head * heads 40 | 41 | self.norm1 = nn.LayerNorm(dim) 42 | self.norm2 = nn.LayerNorm(dim) 43 | 44 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 45 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 46 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 47 | 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class Resampler(nn.Module): 82 | def __init__( 83 | self, 84 | dim=1024, 85 | depth=8, 86 | dim_head=64, 87 | heads=16, 88 | num_queries=8, 89 | embedding_dim=768, 90 | output_dim=1024, 91 | ff_mult=4, 92 | max_seq_len: int = 257, # CLIP tokens + CLS token 93 | apply_pos_emb: bool = False, 94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 95 | ): 96 | super().__init__() 97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 98 | 99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / 
dim**0.5) 100 | 101 | self.proj_in = nn.Linear(embedding_dim, dim) 102 | 103 | self.proj_out = nn.Linear(dim, output_dim) 104 | self.norm_out = nn.LayerNorm(output_dim) 105 | 106 | self.to_latents_from_mean_pooled_seq = ( 107 | nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, dim * num_latents_mean_pooled), 110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 111 | ) 112 | if num_latents_mean_pooled > 0 113 | else None 114 | ) 115 | 116 | self.layers = nn.ModuleList([]) 117 | for _ in range(depth): 118 | self.layers.append( 119 | nn.ModuleList( 120 | [ 121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 122 | FeedForward(dim=dim, mult=ff_mult), 123 | ] 124 | ) 125 | ) 126 | 127 | def forward(self, x): 128 | if self.pos_emb is not None: 129 | n, device = x.shape[1], x.device 130 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 131 | x = x + pos_emb 132 | 133 | latents = self.latents.repeat(x.size(0), 1, 1) 134 | 135 | x = self.proj_in(x) 136 | 137 | if self.to_latents_from_mean_pooled_seq: 138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 140 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 141 | 142 | for attn, ff in self.layers: 143 | latents = attn(x, latents) + latents 144 | latents = ff(latents) + latents 145 | 146 | latents = self.proj_out(latents) 147 | return self.norm_out(latents) 148 | 149 | 150 | def masked_mean(t, *, dim, mask=None): 151 | if mask is None: 152 | return t.mean(dim=dim) 153 | 154 | denom = mask.sum(dim=dim, keepdim=True) 155 | mask = rearrange(mask, "b n -> b n 1") 156 | masked_t = t.masked_fill(~mask, 0.0) 157 | 158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 159 | -------------------------------------------------------------------------------- /src/animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ Conversion script for the LoRA's safetensors checkpoints. 
""" 17 | 18 | import argparse 19 | 20 | import torch 21 | 22 | 23 | def convert_lora( 24 | pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6 25 | ): 26 | # load base model 27 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 28 | 29 | # load LoRA weight from .safetensors 30 | # state_dict = load_file(checkpoint_path) 31 | 32 | visited = [] 33 | 34 | # directly update weight in diffusers model 35 | for key in state_dict: 36 | # it is suggested to print out the key, it usually will be something like below 37 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 38 | 39 | # as we have set the alpha beforehand, so just skip 40 | if ".alpha" in key or key in visited: 41 | continue 42 | 43 | if "text" in key: 44 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 45 | curr_layer = pipeline.text_encoder 46 | else: 47 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 48 | curr_layer = pipeline.unet 49 | 50 | # find the target layer 51 | temp_name = layer_infos.pop(0) 52 | while len(layer_infos) > -1: 53 | try: 54 | curr_layer = curr_layer.__getattr__(temp_name) 55 | if len(layer_infos) > 0: 56 | temp_name = layer_infos.pop(0) 57 | elif len(layer_infos) == 0: 58 | break 59 | except Exception: 60 | if len(temp_name) > 0: 61 | temp_name += "_" + layer_infos.pop(0) 62 | else: 63 | temp_name = layer_infos.pop(0) 64 | 65 | pair_keys = [] 66 | if "lora_down" in key: 67 | pair_keys.append(key.replace("lora_down", "lora_up")) 68 | pair_keys.append(key) 69 | else: 70 | pair_keys.append(key) 71 | pair_keys.append(key.replace("lora_up", "lora_down")) 72 | 73 | # update weight 74 | if len(state_dict[pair_keys[0]].shape) == 4: 75 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 76 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 77 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( 78 | curr_layer.weight.data.device 79 | ) 80 | else: 81 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 82 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 83 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to( 84 | curr_layer.weight.data.device 85 | ) 86 | 87 | # update visited list 88 | for item in pair_keys: 89 | visited.append(item) 90 | 91 | return pipeline 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | 97 | parser.add_argument( 98 | "--base_model_path", 99 | default=None, 100 | type=str, 101 | required=True, 102 | help="Path to the base model in diffusers format.", 103 | ) 104 | parser.add_argument( 105 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 106 | ) 107 | parser.add_argument( 108 | "--dump_path", default=None, type=str, required=True, help="Path to the output model." 
109 | ) 110 | parser.add_argument( 111 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 112 | ) 113 | parser.add_argument( 114 | "--lora_prefix_text_encoder", 115 | default="lora_te", 116 | type=str, 117 | help="The prefix of text encoder weight in safetensors", 118 | ) 119 | parser.add_argument( 120 | "--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW" 121 | ) 122 | parser.add_argument( 123 | "--to_safetensors", 124 | action="store_true", 125 | help="Whether to store pipeline in safetensors format or not.", 126 | ) 127 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 128 | 129 | args = parser.parse_args() 130 | 131 | base_model_path = args.base_model_path 132 | checkpoint_path = args.checkpoint_path 133 | dump_path = args.dump_path 134 | lora_prefix_unet = args.lora_prefix_unet 135 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 136 | alpha = args.alpha 137 | 138 | pipe = convert_lora(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 139 | 140 | pipe = pipe.to(args.device) 141 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 142 | -------------------------------------------------------------------------------- /src/animatediff/settings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from os import PathLike 4 | from pathlib import Path 5 | from typing import Any, Dict, Optional, Tuple, Union 6 | 7 | from pydantic.v1 import BaseConfig, BaseSettings, Field 8 | from pydantic.env_settings import (EnvSettingsSource, InitSettingsSource, 9 | SecretsSettingsSource, 10 | SettingsSourceCallable) 11 | 12 | from animatediff import get_dir 13 | from animatediff.schedulers import DiffusionScheduler 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | CKPT_EXTENSIONS = [".pt", ".ckpt", ".pth", ".safetensors"] 19 | 20 | 21 | class JsonSettingsSource: 22 | __slots__ = ["json_config_path"] 23 | 24 | def __init__( 25 | self, 26 | json_config_path: Optional[Union[PathLike, list[PathLike]]] = list(), 27 | ) -> None: 28 | if isinstance(json_config_path, list): 29 | self.json_config_path = [Path(path) for path in json_config_path] 30 | else: 31 | self.json_config_path = [Path(json_config_path)] if json_config_path is not None else [] 32 | 33 | def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 34 | classname = settings.__class__.__name__ 35 | encoding = settings.__config__.env_file_encoding 36 | if len(self.json_config_path) == 0: 37 | pass # no json config provided 38 | 39 | merged_config = dict() # create an empty dict to merge configs into 40 | for idx, path in enumerate(self.json_config_path): 41 | if path.exists() and path.is_file(): # check if the path exists and is a file 42 | logger.debug(f"{classname}: loading config #{idx+1} from {path}") 43 | merged_config.update(json.loads(path.read_text(encoding=encoding))) 44 | logger.debug(f"{classname}: config state #{idx+1}: {merged_config}") 45 | else: 46 | raise FileNotFoundError(f"{classname}: config #{idx+1} at {path} not found or not a file") 47 | 48 | logger.debug(f"{classname}: loaded config: {merged_config}") 49 | return merged_config # return the merged config 50 | 51 | def __repr__(self) -> str: 52 | return f"JsonSettingsSource(json_config_path={repr(self.json_config_path)})" 53 | 54 | 55 | 
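# Note: JsonSettingsSource merges the JSON files listed in json_config_path in order,
# so keys from files later in the list override keys from earlier ones. A minimal,
# hypothetical illustration (the file names below are placeholders, not files that
# ship with this repository):
#
#   source = JsonSettingsSource(json_config_path=["base.json", "override.json"])
#   merged = source(some_settings_instance)
#   # -> contents of base.json updated with (and overridden by) override.json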
class JsonConfig(BaseConfig): 56 | json_config_path: Optional[Union[Path, list[Path]]] = None 57 | env_file_encoding: str = "utf-8" 58 | 59 | @classmethod 60 | def customise_sources( 61 | cls, 62 | init_settings: InitSettingsSource, 63 | env_settings: EnvSettingsSource, 64 | file_secret_settings: SecretsSettingsSource, 65 | ) -> Tuple[SettingsSourceCallable, ...]: 66 | # pull json_config_path from init_settings if passed, otherwise use the class var 67 | json_config_path = init_settings.init_kwargs.pop("json_config_path", cls.json_config_path) 68 | 69 | logger.debug(f"Using JsonSettingsSource for {cls.__name__}") 70 | json_settings = JsonSettingsSource(json_config_path=json_config_path) 71 | 72 | # return the new settings sources 73 | return ( 74 | init_settings, 75 | json_settings, 76 | ) 77 | 78 | 79 | class InferenceConfig(BaseSettings): 80 | unet_additional_kwargs: dict[str, Any] 81 | noise_scheduler_kwargs: dict[str, Any] 82 | 83 | class Config(JsonConfig): 84 | json_config_path: Path 85 | 86 | 87 | def get_infer_config( 88 | is_v2:bool, 89 | is_sdxl:bool, 90 | ) -> InferenceConfig: 91 | config_path: Path = get_dir("config").joinpath("inference/default.json" if not is_v2 else "inference/motion_v2.json") 92 | 93 | if is_sdxl: 94 | config_path = get_dir("config").joinpath("inference/motion_sdxl.json") 95 | 96 | settings = InferenceConfig(json_config_path=config_path) 97 | return settings 98 | 99 | 100 | class ModelConfig(BaseSettings): 101 | name: str = Field(...) # Config name, not actually used for much of anything 102 | path: Path = Field(...) # Path to the model 103 | vae_path: str = "" # Path to the model 104 | motion_module: Path = Field(...) # Path to the motion module 105 | context_schedule: str = "uniform" 106 | lcm_map: Dict[str,Any]= Field({}) 107 | gradual_latent_hires_fix_map: Dict[str,Any]= Field({}) 108 | compile: bool = Field(False) # whether to compile the model with TorchDynamo 109 | tensor_interpolation_slerp: bool = Field(True) 110 | seed: list[int] = Field([]) # Seed(s) for the random number generators 111 | scheduler: DiffusionScheduler = Field(DiffusionScheduler.k_dpmpp_2m) # Scheduler to use 112 | steps: int = 25 # Number of inference steps to run 113 | guidance_scale: float = 7.5 # CFG scale to use 114 | unet_batch_size: int = 1 115 | clip_skip: int = 1 # skip the last N-1 layers of the CLIP text encoder 116 | prompt_fixed_ratio: float = 0.5 117 | head_prompt: str = "" 118 | prompt_map: Dict[str,str]= Field({}) 119 | tail_prompt: str = "" 120 | n_prompt: list[str] = Field([]) # Anti-prompt(s) to use 121 | is_single_prompt_mode : bool = Field(False) 122 | lora_map: Dict[str,Any]= Field({}) 123 | motion_lora_map: Dict[str,float]= Field({}) 124 | ip_adapter_map: Dict[str,Any]= Field({}) 125 | img2img_map: Dict[str,Any]= Field({}) 126 | region_map: Dict[str,Any]= Field({}) 127 | controlnet_map: Dict[str,Any]= Field({}) 128 | upscale_config: Dict[str,Any]= Field({}) 129 | stylize_config: Dict[str,Any]= Field({}) 130 | output: Dict[str,Any]= Field({}) 131 | result: Dict[str,Any]= Field({}) 132 | 133 | class Config(JsonConfig): 134 | json_config_path: Path 135 | 136 | @property 137 | def save_name(self): 138 | return f"{self.name.lower()}-{self.path.stem.lower()}" 139 | 140 | 141 | def get_model_config(config_path: Path) -> ModelConfig: 142 | settings = ModelConfig(json_config_path=config_path) 143 | return settings 144 | -------------------------------------------------------------------------------- /src/animatediff/utils/tagger.py: 
-------------------------------------------------------------------------------- 1 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py 2 | 3 | import glob 4 | import logging 5 | import os 6 | 7 | import cv2 8 | import numpy as np 9 | import onnxruntime 10 | import pandas as pd 11 | from PIL import Image 12 | from tqdm.rich import tqdm 13 | 14 | from animatediff.utils.util import prepare_wd14tagger 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def make_square(img, target_size): 20 | old_size = img.shape[:2] 21 | desired_size = max(old_size) 22 | desired_size = max(desired_size, target_size) 23 | 24 | delta_w = desired_size - old_size[1] 25 | delta_h = desired_size - old_size[0] 26 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 27 | left, right = delta_w // 2, delta_w - (delta_w // 2) 28 | 29 | color = [255, 255, 255] 30 | new_im = cv2.copyMakeBorder( 31 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 32 | ) 33 | return new_im 34 | 35 | def smart_resize(img, size): 36 | # Assumes the image has already gone through make_square 37 | if img.shape[0] > size: 38 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 39 | elif img.shape[0] < size: 40 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 41 | return img 42 | 43 | 44 | class Tagger: 45 | def __init__(self, general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format,is_cpu): 46 | prepare_wd14tagger() 47 | # self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CUDAExecutionProvider','CPUExecutionProvider']) 48 | if is_cpu: 49 | self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CPUExecutionProvider']) 50 | else: 51 | self.model = onnxruntime.InferenceSession("data/models/WD14tagger/model.onnx", providers=['CUDAExecutionProvider']) 52 | df = pd.read_csv("data/models/WD14tagger/selected_tags.csv") 53 | self.tag_names = df["name"].tolist() 54 | self.rating_indexes = list(np.where(df["category"] == 9)[0]) 55 | self.general_indexes = list(np.where(df["category"] == 0)[0]) 56 | self.character_indexes = list(np.where(df["category"] == 4)[0]) 57 | 58 | self.general_threshold = general_threshold 59 | self.character_threshold = character_threshold 60 | self.ignore_tokens = ignore_tokens 61 | self.with_confidence = with_confidence 62 | self.is_danbooru_format = is_danbooru_format 63 | 64 | def __call__( 65 | self, 66 | image: Image, 67 | ): 68 | 69 | _, height, width, _ = self.model.get_inputs()[0].shape 70 | 71 | # Alpha to white 72 | image = image.convert("RGBA") 73 | new_image = Image.new("RGBA", image.size, "WHITE") 74 | new_image.paste(image, mask=image) 75 | image = new_image.convert("RGB") 76 | image = np.asarray(image) 77 | 78 | # PIL RGB to OpenCV BGR 79 | image = image[:, :, ::-1] 80 | 81 | image = make_square(image, height) 82 | image = smart_resize(image, height) 83 | image = image.astype(np.float32) 84 | image = np.expand_dims(image, 0) 85 | 86 | input_name = self.model.get_inputs()[0].name 87 | label_name = self.model.get_outputs()[0].name 88 | probs = self.model.run([label_name], {input_name: image})[0] 89 | 90 | labels = list(zip(self.tag_names, probs[0].astype(float))) 91 | 92 | # First 4 labels are actually ratings: pick one with argmax 93 | ratings_names = [labels[i] for i in self.rating_indexes] 94 | rating = dict(ratings_names) 95 | 96 | # Then we have general tags: pick any where prediction confidence > threshold 97 | 
general_names = [labels[i] for i in self.general_indexes] 98 | general_res = [x for x in general_names if x[1] > self.general_threshold] 99 | general_res = dict(general_res) 100 | 101 | # Everything else is characters: pick any where prediction confidence > threshold 102 | character_names = [labels[i] for i in self.character_indexes] 103 | character_res = [x for x in character_names if x[1] > self.character_threshold] 104 | character_res = dict(character_res) 105 | 106 | #logger.info(f"{rating=}") 107 | #logger.info(f"{general_res=}") 108 | #logger.info(f"{character_res=}") 109 | 110 | general_res = {k:general_res[k] for k in (general_res.keys() - set(self.ignore_tokens)) } 111 | character_res = {k:character_res[k] for k in (character_res.keys() - set(self.ignore_tokens)) } 112 | 113 | prompt = "" 114 | 115 | if self.with_confidence: 116 | prompt = [ f"({i}:{character_res[i]:.2f})" for i in (character_res.keys()) ] 117 | prompt += [ f"({i}:{general_res[i]:.2f})" for i in (general_res.keys()) ] 118 | else: 119 | prompt = [ i for i in (character_res.keys()) ] 120 | prompt += [ i for i in (general_res.keys()) ] 121 | 122 | prompt = ",".join(prompt) 123 | 124 | if not self.is_danbooru_format: 125 | prompt = prompt.replace("_", " ") 126 | 127 | #logger.info(f"{prompt=}") 128 | return prompt 129 | 130 | 131 | def get_labels(frame_dir, interval, general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format, is_cpu =False): 132 | 133 | import torch 134 | 135 | result = {} 136 | if os.path.isdir(frame_dir): 137 | png_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False)) 138 | 139 | png_map ={} 140 | for png_path in png_list: 141 | basename_without_ext = os.path.splitext(os.path.basename(png_path))[0] 142 | png_map[int(basename_without_ext)] = png_path 143 | 144 | with torch.no_grad(): 145 | tagger = Tagger(general_threshold, character_threshold, ignore_tokens, with_confidence, is_danbooru_format, is_cpu) 146 | 147 | for i in tqdm(range(0, len(png_list), interval ), desc=f"WD14tagger"): 148 | path = png_map[i] 149 | 150 | #logger.info(f"{path=}") 151 | 152 | result[str(i)] = tagger( 153 | image= Image.open(path) 154 | ) 155 | 156 | tagger = None 157 | 158 | torch.cuda.empty_cache() 159 | 160 | return result 161 | 162 | -------------------------------------------------------------------------------- /src/animatediff/utils/composite.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import shutil 5 | from pathlib import Path 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from PIL import Image 12 | from tqdm.rich import tqdm 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | #https://github.com/jinwonkim93/laplacian-pyramid-blend 18 | #https://blog.shikoan.com/pytorch-laplacian-pyramid/ 19 | class LaplacianPyramidBlender: 20 | 21 | device = None 22 | 23 | def get_gaussian_kernel(self): 24 | kernel = np.array([ 25 | [1, 4, 6, 4, 1], 26 | [4, 16, 24, 16, 4], 27 | [6, 24, 36, 24, 6], 28 | [4, 16, 24, 16, 4], 29 | [1, 4, 6, 4, 1]], np.float32) / 256.0 30 | gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5),device=self.device) 31 | return gaussian_k 32 | 33 | def pyramid_down(self, image): 34 | with torch.no_grad(): 35 | gaussian_k = self.get_gaussian_kernel() 36 | multiband = [F.conv2d(image[:, i:i + 1,:,:], gaussian_k, padding=2, stride=2) for i in range(3)] 37 | down_image = 
torch.cat(multiband, dim=1) 38 | return down_image 39 | 40 | def pyramid_up(self, image, size = None): 41 | with torch.no_grad(): 42 | gaussian_k = self.get_gaussian_kernel() 43 | if size is None: 44 | upsample = F.interpolate(image, scale_factor=2) 45 | else: 46 | upsample = F.interpolate(image, size=size) 47 | multiband = [F.conv2d(upsample[:, i:i + 1,:,:], gaussian_k, padding=2) for i in range(3)] 48 | up_image = torch.cat(multiband, dim=1) 49 | return up_image 50 | 51 | def gaussian_pyramid(self, original, n_pyramids): 52 | x = original 53 | # pyramid down 54 | pyramids = [original] 55 | for i in range(n_pyramids): 56 | x = self.pyramid_down(x) 57 | pyramids.append(x) 58 | return pyramids 59 | 60 | def laplacian_pyramid(self, original, n_pyramids): 61 | pyramids = self.gaussian_pyramid(original, n_pyramids) 62 | 63 | # pyramid up - diff 64 | laplacian = [] 65 | for i in range(len(pyramids) - 1): 66 | diff = pyramids[i] - self.pyramid_up(pyramids[i + 1], pyramids[i].shape[2:]) 67 | laplacian.append(diff) 68 | 69 | laplacian.append(pyramids[-1]) 70 | return laplacian 71 | 72 | def laplacian_pyramid_blending_with_mask(self, src, target, mask, num_levels = 9): 73 | # assume mask is float32 [0,1] 74 | 75 | # generate Gaussian pyramid for src,target and mask 76 | 77 | Gsrc = torch.as_tensor(np.expand_dims(src, axis=0), device=self.device) 78 | Gtarget = torch.as_tensor(np.expand_dims(target, axis=0), device=self.device) 79 | Gmask = torch.as_tensor(np.expand_dims(mask, axis=0), device=self.device) 80 | 81 | lpA = self.laplacian_pyramid(Gsrc,num_levels)[::-1] 82 | lpB = self.laplacian_pyramid(Gtarget,num_levels)[::-1] 83 | gpMr = self.gaussian_pyramid(Gmask,num_levels)[::-1] 84 | 85 | # Now blend images according to mask in each level 86 | LS = [] 87 | for idx, (la,lb,Gmask) in enumerate(zip(lpA,lpB,gpMr)): 88 | lo = lb * (1.0 - Gmask) 89 | if idx <= 2: 90 | lo += lb * Gmask 91 | else: 92 | lo += la * Gmask 93 | LS.append(lo) 94 | 95 | # now reconstruct 96 | ls_ = LS.pop(0) 97 | for lap in LS: 98 | ls_ = self.pyramid_up(ls_, lap.shape[2:]) + lap 99 | 100 | result = ls_.squeeze(dim=0).to('cpu').detach().numpy().copy() 101 | 102 | return result 103 | 104 | def __call__(self, 105 | src_image: np.ndarray, 106 | target_image: np.ndarray, 107 | mask_image: np.ndarray, 108 | device 109 | ): 110 | 111 | self.device = device 112 | 113 | num_levels = int(np.log2(src_image.shape[0])) 114 | #normalize image to 0, 1 115 | mask_image = np.clip(mask_image, 0, 1).transpose([2, 0, 1]) 116 | 117 | src_image = src_image.transpose([2, 0, 1]).astype(np.float32) / 255.0 118 | target_image = target_image.transpose([2, 0, 1]).astype(np.float32) / 255.0 119 | composite_image = self.laplacian_pyramid_blending_with_mask(src_image, target_image, mask_image, num_levels) 120 | composite_image = np.clip(composite_image*255, 0 , 255).astype(np.uint8) 121 | composite_image=composite_image.transpose([1, 2, 0]) 122 | return composite_image 123 | 124 | 125 | def composite(bg_dir, fg_list, output_dir, masked_area_list, device="cuda"): 126 | bg_list = sorted(glob.glob( os.path.join(bg_dir ,"[0-9]*.png"), recursive=False)) 127 | 128 | blender = LaplacianPyramidBlender() 129 | 130 | for bg, fg_array, mask in tqdm(zip(bg_list, fg_list, masked_area_list),total=len(bg_list), desc="compositing"): 131 | name = Path(bg).name 132 | save_path = output_dir / name 133 | 134 | if fg_array is None: 135 | logger.info(f"composite fg_array is None -> skip") 136 | shutil.copy(bg, save_path) 137 | continue 138 | 139 | if mask is None: 140 | 
logger.info(f"mask is None -> skip") 141 | shutil.copy(bg, save_path) 142 | continue 143 | 144 | bg = np.asarray(Image.open(bg)).copy() 145 | fg = fg_array 146 | mask = np.concatenate([mask, mask, mask], 2) 147 | 148 | h, w, _ = bg.shape 149 | 150 | fg = cv2.resize(fg, dsize=(w,h)) 151 | mask = cv2.resize(mask, dsize=(w,h)) 152 | 153 | 154 | mask = mask.astype(np.float32) 155 | # mask = mask * 255 156 | mask = cv2.GaussianBlur(mask, (15, 15), 0) 157 | mask = mask / 255 158 | 159 | fg = fg * mask + bg * (1-mask) 160 | 161 | img = blender(fg, bg, mask,device) 162 | 163 | 164 | img = Image.fromarray(img) 165 | img.save(save_path) 166 | 167 | def simple_composite(bg_dir, fg_list, output_dir, masked_area_list, device="cuda"): 168 | bg_list = sorted(glob.glob( os.path.join(bg_dir ,"[0-9]*.png"), recursive=False)) 169 | 170 | for bg, fg_array, mask in tqdm(zip(bg_list, fg_list, masked_area_list),total=len(bg_list), desc="compositing"): 171 | name = Path(bg).name 172 | save_path = output_dir / name 173 | 174 | if fg_array is None: 175 | logger.info(f"composite fg_array is None -> skip") 176 | shutil.copy(bg, save_path) 177 | continue 178 | 179 | if mask is None: 180 | logger.info(f"mask is None -> skip") 181 | shutil.copy(bg, save_path) 182 | continue 183 | 184 | bg = np.asarray(Image.open(bg)).copy() 185 | fg = fg_array 186 | mask = np.concatenate([mask, mask, mask], 2) 187 | 188 | h, w, _ = bg.shape 189 | 190 | fg = cv2.resize(fg, dsize=(w,h)) 191 | mask = cv2.resize(mask, dsize=(w,h)) 192 | 193 | 194 | mask = mask.astype(np.float32) 195 | mask = cv2.GaussianBlur(mask, (15, 15), 0) 196 | mask = mask / 255 197 | 198 | img = fg * mask + bg * (1-mask) 199 | img = img.clip(0 , 255).astype(np.uint8) 200 | 201 | img = Image.fromarray(img) 202 | img.save(save_path) -------------------------------------------------------------------------------- /example.md: -------------------------------------------------------------------------------- 1 | ### Example 2 | 3 | - region prompt(txt2img / no controlnet) 4 | - region 0 ... 1girl, upper body etc 5 | - region 1 ... ((car)), street, road,no human etc 6 | - background ... town, outdoors etc 7 | - ip adapter input for background / region 0 / region 1 8 | 9 | 10 | - animatediff generate -c config/prompts/region_txt2img.json -W 512 -H 768 -L 32 -C 16 11 | - region 0 mask / region 1 mask / txt2img 12 | 13 |
14 | 15 | 16 | 17 |
18 |
19 | 20 | - apply a different lora to each region. 21 | - [abdiel](https://civitai.com/models/159943/abdiel-shin-megami-tensei-v-v) for region 0 22 | - [amanozoko](https://civitai.com/models/159933/amanozoko-shin-megami-tensei-v-v) for region 1 23 | - no lora for background 24 | 
25 | 26 | ```json 27 | # new lora_map format 28 | "lora_map": { 29 | # Specify each lora as a path relative to /animatediff-cli/data 30 | "share/Lora/zs_Abdiel.safetensors": { # settings for the abdiel lora 31 | "region" : ["0"], # target region(s). Multiple regions can be specified 32 | "scale" : { 33 | # "frame_no" : scale format 34 | "0": 0.75 # lora scale. Same format as prompt_map; for example, you can start applying the lora from frame 30. 35 | } 36 | }, 37 | "share/Lora/zs_Amanazoko.safetensors": { # settings for the amanozako lora 38 | "region" : ["1"], # target region 39 | "scale" : { 40 | "0": 0.75 41 | } 42 | } 43 | }, 44 | ``` 45 | - more examples [here](https://github.com/s9roll7/animatediff-cli-prompt-travel/issues/147) 46 | 
47 | 48 | 49 | 50 | - img2img 51 | - This can be improved using controlnet, but this sample does not use it. 52 | - source / denoising_strength 0.7 / denoising_strength 0.85 53 |
54 |
55 |
56 | 57 | - [A command for stylization with region has been added](https://github.com/s9roll7/animatediff-cli-prompt-travel#video-stylization-with-region). 58 | - (You can also create the json manually without using the stylize command.) 59 | - region prompt 60 | - Region division into person shapes 61 | - source / img2img / txt2img 62 | 
63 |
64 | 65 | - source / Region division into person shapes / inpaint 66 |
67 |
68 |
69 | 70 | 71 | 72 | 73 | - [A command for stylization with mask has been added](https://github.com/s9roll7/animatediff-cli-prompt-travel#video-stylization-with-mask). 74 | - more examples [here](https://github.com/s9roll7/animatediff-cli-prompt-travel/issues/111) 75 | 76 | 
77 |
78 | 79 | 80 | - [A command to automate video stylization has been added](https://github.com/s9roll7/animatediff-cli-prompt-travel#video-stylization). 81 | - Original / First generation result / Second generation (for upscaling) result 82 | - It took 4 minutes to generate the first one and about 5 minutes to generate the second one (on an RTX 4090). 83 | - more examples [here](https://github.com/s9roll7/animatediff-cli-prompt-travel/issues/29) 84 | 85 | 
86 |
87 | 88 | 89 | - controlnet_openpose + controlnet_softedge 90 | - input frames for controlnet (0, 16, 32 frames) 91 | 92 | 93 | - result 94 | 
95 |
96 | 97 | - In the latest version, generation can now be controlled more precisely through prompts. 98 | - sample 1 99 | ```json 100 | "prompt_fixed_ratio": 0.8, 101 | "head_prompt": "1girl, wizard, circlet, earrings, jewelry, purple hair,", 102 | "prompt_map": { 103 | "0": "(standing,full_body),blue_sky, town", 104 | "8": "(sitting,full_body),rain, town", 105 | "16": "(standing,full_body),blue_sky, woods", 106 | "24": "(upper_body), beach", 107 | "32": "(upper_body, smile)", 108 | "40": "(upper_body, angry)", 109 | "48": "(upper_body, smile, from_above)", 110 | "56": "(upper_body, angry, from_side)", 111 | "64": "(upper_body, smile, from_below)", 112 | "72": "(upper_body, angry, from_behind, looking at viewer)", 113 | "80": "face,looking at viewer", 114 | "88": "face,looking at viewer, closed_eyes", 115 | "96": "face,looking at viewer, open eyes, open_mouth", 116 | "104": "face,looking at viewer, closed_eyes, closed_mouth", 117 | "112": "face,looking at viewer, open eyes,eyes, open_mouth, tongue, smile, laughing", 118 | "120": "face,looking at viewer, eating, bowl,chopsticks,holding,food" 119 | }, 120 | ``` 121 |
122 |
123 | 124 | - sample 2 125 | ```json 126 | "prompt_fixed_ratio": 1.0, 127 | "head_prompt": "1girl, wizard, circlet, earrings, jewelry, purple hair,", 128 | "prompt_map": { 129 | "0": "", 130 | "8": "((fire magic spell, fire background))", 131 | "16": "((ice magic spell, ice background))", 132 | "24": "((thunder magic spell, thunder background))", 133 | "32": "((skull magic spell, skull background))", 134 | "40": "((wind magic spell, wind background))", 135 | "48": "((stone magic spell, stone background))", 136 | "56": "((holy magic spell, holy background))", 137 | "64": "((star magic spell, star background))", 138 | "72": "((plant magic spell, plant background))", 139 | "80": "((meteor magic spell, meteor background))" 140 | }, 141 | ``` 142 |
143 |
144 | 145 | -------------------------------------------------------------------------------- /src/animatediff/models/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling 6 | from transformers.models.clip import CLIPPreTrainedModel, CLIPTextConfig, CLIPTextModel 7 | from transformers.models.clip.modeling_clip import ( 8 | CLIP_TEXT_INPUTS_DOCSTRING, 9 | CLIPTextTransformer, 10 | _expand_mask, 11 | _make_causal_mask, 12 | ) 13 | from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings 14 | 15 | CLIP_SKIP_TEXT_INPUTS_DOCSTRING = ( 16 | CLIP_TEXT_INPUTS_DOCSTRING 17 | + r""" 18 | clip_skip (`int`, *optional*, defaults to 1): 19 | Skip the final N layers of the CLIP text encoder. Some Diffusion models were trained 20 | using the hidden states from the 2nd-last layer of the CLIP text encoder (ie clip_skip=2), 21 | so we reproduce that behavior here for use with those models. 22 | """ 23 | ) 24 | 25 | 26 | class CLIPSkipTextTransformer(CLIPTextTransformer): 27 | @add_start_docstrings_to_model_forward(CLIP_SKIP_TEXT_INPUTS_DOCSTRING) 28 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) 29 | def forward( 30 | self, 31 | input_ids: Optional[torch.Tensor] = None, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.Tensor] = None, 34 | output_attentions: Optional[bool] = None, 35 | output_hidden_states: Optional[bool] = None, 36 | return_dict: Optional[bool] = None, 37 | clip_skip: int = 1, 38 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 39 | r""" 40 | Returns: 41 | 42 | """ 43 | output_attentions = ( 44 | output_attentions if output_attentions is not None else self.config.output_attentions 45 | ) 46 | output_hidden_states = ( 47 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 48 | ) 49 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 50 | 51 | if input_ids is None: 52 | raise ValueError("You have to specify input_ids") 53 | 54 | input_shape = input_ids.size() 55 | input_ids = input_ids.view(-1, input_shape[-1]) 56 | 57 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) 58 | 59 | # CLIP's text model uses causal mask, prepare it here. 
60 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 61 | causal_attention_mask = _make_causal_mask( 62 | input_shape, hidden_states.dtype, device=hidden_states.device 63 | ) 64 | # expand attention_mask 65 | if attention_mask is not None: 66 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 67 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype) 68 | 69 | encoder_outputs: BaseModelOutput = self.encoder( 70 | inputs_embeds=hidden_states, 71 | attention_mask=attention_mask, 72 | causal_attention_mask=causal_attention_mask, 73 | output_attentions=output_attentions, 74 | output_hidden_states=True, 75 | return_dict=True, 76 | ) 77 | 78 | # take the hidden state from the Nth-to-last layer of the encoder, where N = clip_skip 79 | # clip_skip=1 means take the hidden state from the last layer as with CLIPTextTransformer 80 | last_hidden_state = encoder_outputs.hidden_states[-clip_skip] 81 | last_hidden_state = self.final_layer_norm(last_hidden_state) 82 | 83 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 84 | # take features from the eot embedding (eot_token is the highest number in each sequence) 85 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 86 | pooled_output = last_hidden_state[ 87 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 88 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), 89 | ] 90 | 91 | if not return_dict: 92 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 93 | 94 | return BaseModelOutputWithPooling( 95 | last_hidden_state=last_hidden_state, 96 | pooler_output=pooled_output, 97 | hidden_states=encoder_outputs.hidden_states, 98 | attentions=encoder_outputs.attentions, 99 | ) 100 | 101 | def _build_causal_attention_mask(self, bsz, seq_len, dtype): 102 | # lazily create causal attention mask, with full attention between the vision tokens 103 | # pytorch uses additive attention mask; fill with -inf 104 | mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) 105 | mask.fill_(torch.tensor(torch.finfo(dtype).min)) 106 | mask.triu_(1) # zero out the lower diagonal 107 | mask = mask.unsqueeze(1) # expand mask 108 | return mask 109 | 110 | 111 | class CLIPSkipTextModel(CLIPTextModel): 112 | config_class = CLIPTextConfig 113 | 114 | _no_split_modules = ["CLIPEncoderLayer"] 115 | 116 | def __init__(self, config: CLIPTextConfig): 117 | super().__init__(config) 118 | self.text_model = CLIPSkipTextTransformer(config) 119 | # Initialize weights and apply final processing 120 | self.post_init() 121 | 122 | @add_start_docstrings_to_model_forward(CLIP_SKIP_TEXT_INPUTS_DOCSTRING) 123 | @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) 124 | def forward( 125 | self, 126 | input_ids: Optional[torch.Tensor] = None, 127 | attention_mask: Optional[torch.Tensor] = None, 128 | position_ids: Optional[torch.Tensor] = None, 129 | output_attentions: Optional[bool] = None, 130 | output_hidden_states: Optional[bool] = None, 131 | return_dict: Optional[bool] = None, 132 | clip_skip: int = 1, 133 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 134 | r""" 135 | Returns: 136 | 137 | Examples: 138 | 139 | ```python 140 | >>> from transformers import AutoTokenizer, CLIPSkipTextModel 141 | 142 | >>> model = CLIPSkipTextModel.from_pretrained("openai/clip-vit-base-patch32") 143 | >>> tokenizer = 
AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 144 | 145 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 146 | 147 | >>> outputs = model(**inputs) 148 | >>> last_hidden_state = outputs.last_hidden_state 149 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 150 | ```""" 151 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 152 | 153 | return self.text_model( 154 | input_ids=input_ids, 155 | attention_mask=attention_mask, 156 | position_ids=position_ids, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | return_dict=return_dict, 160 | clip_skip=clip_skip, 161 | ) 162 | -------------------------------------------------------------------------------- /src/animatediff/rife/rife.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import subprocess 3 | from math import ceil 4 | from pathlib import Path 5 | from typing import Annotated, Optional 6 | 7 | import typer 8 | 9 | from animatediff import get_dir 10 | 11 | from .ffmpeg import FfmpegEncoder, VideoCodec, codec_extn 12 | from .ncnn import RifeNCNNOptions 13 | 14 | rife_dir = get_dir("data/rife") 15 | rife_ncnn_vulkan = rife_dir.joinpath("rife-ncnn-vulkan") 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | app: typer.Typer = typer.Typer( 20 | name="rife", 21 | context_settings=dict(help_option_names=["-h", "--help"]), 22 | rich_markup_mode="rich", 23 | pretty_exceptions_show_locals=False, 24 | help="RIFE motion flow interpolation (MORE FPS!)", 25 | ) 26 | 27 | def rife_interpolate( 28 | input_frames_dir:str, 29 | output_frames_dir:str, 30 | frame_multiplier:int = 2, 31 | rife_model:str = "rife-v4.6", 32 | spatial_tta:bool = False, 33 | temporal_tta:bool = False, 34 | uhd:bool = False, 35 | ): 36 | 37 | rife_model_dir = rife_dir.joinpath(rife_model) 38 | if not rife_model_dir.joinpath("flownet.bin").exists(): 39 | raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!") 40 | 41 | 42 | rife_opts = RifeNCNNOptions( 43 | model_path=rife_model_dir, 44 | input_path=input_frames_dir, 45 | output_path=output_frames_dir, 46 | time_step=1 / frame_multiplier, 47 | spatial_tta=spatial_tta, 48 | temporal_tta=temporal_tta, 49 | uhd=uhd, 50 | ) 51 | rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier) 52 | 53 | # actually run RIFE 54 | logger.info("Running RIFE, this may take a little while...") 55 | with subprocess.Popen( 56 | [rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.PIPE 57 | ) as proc: 58 | errs = [] 59 | for line in proc.stderr: 60 | line = line.decode("utf-8").strip() 61 | if line: 62 | logger.debug(line) 63 | stdout, _ = proc.communicate() 64 | if proc.returncode != 0: 65 | raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(errs)) 66 | 67 | import glob 68 | import os 69 | org_images = sorted(glob.glob( os.path.join(output_frames_dir, "[0-9]*.png"), recursive=False)) 70 | for o in org_images: 71 | p = Path(o) 72 | new_no = int(p.stem) - 1 73 | new_p = p.with_stem(f"{new_no:08d}") 74 | p.rename(new_p) 75 | 76 | 77 | 78 | @app.command(no_args_is_help=True) 79 | def interpolate( 80 | rife_model: Annotated[ 81 | str, 82 | typer.Option("--rife-model", "-m", help="RIFE model to use (subdirectory of data/rife/)"), 83 | ] = "rife-v4.6", 84 | in_fps: Annotated[ 85 | int, 86 | typer.Option("--in-fps", "-I", help="Input frame 
FPS (8 for AnimateDiff)", show_default=True), 87 | ] = 8, 88 | frame_multiplier: Annotated[ 89 | int, 90 | typer.Option( 91 | "--frame-multiplier", "-M", help="Multiply total frame count by this", show_default=True 92 | ), 93 | ] = 8, 94 | out_fps: Annotated[ 95 | int, 96 | typer.Option("--out-fps", "-F", help="Target FPS", show_default=True), 97 | ] = 50, 98 | codec: Annotated[ 99 | VideoCodec, 100 | typer.Option("--codec", "-c", help="Output video codec", show_default=True), 101 | ] = VideoCodec.webm, 102 | lossless: Annotated[ 103 | bool, 104 | typer.Option("--lossless", "-L", is_flag=True, help="Use lossless encoding (WebP only)"), 105 | ] = False, 106 | spatial_tta: Annotated[ 107 | bool, 108 | typer.Option("--spatial-tta", "-x", is_flag=True, help="Enable RIFE Spatial TTA mode"), 109 | ] = False, 110 | temporal_tta: Annotated[ 111 | bool, 112 | typer.Option("--temporal-tta", "-z", is_flag=True, help="Enable RIFE Temporal TTA mode"), 113 | ] = False, 114 | uhd: Annotated[ 115 | bool, 116 | typer.Option("--uhd", "-u", is_flag=True, help="Enable RIFE UHD mode"), 117 | ] = False, 118 | frames_dir: Annotated[ 119 | Path, 120 | typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to source frames directory"), 121 | ] = ..., 122 | out_file: Annotated[ 123 | Optional[Path], 124 | typer.Argument( 125 | dir_okay=False, 126 | help="Path to output file (default: frames_dir/rife-output.)", 127 | show_default=False, 128 | ), 129 | ] = None, 130 | ): 131 | rife_model_dir = rife_dir.joinpath(rife_model) 132 | if not rife_model_dir.joinpath("flownet.bin").exists(): 133 | raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!") 134 | 135 | if not frames_dir.exists(): 136 | raise FileNotFoundError(f"Frames directory {frames_dir} does not exist!") 137 | 138 | # where to put the RIFE interpolated frames (default: frames_dir/../-rife) 139 | # TODO: make this configurable? 140 | rife_frames_dir = frames_dir.parent.joinpath(f"{frames_dir.name}-rife") 141 | rife_frames_dir.mkdir(exist_ok=True, parents=True) 142 | 143 | # build output file path 144 | file_extn = codec_extn(codec) 145 | if out_file is None: 146 | out_file = frames_dir.parent.joinpath(f"{frames_dir.name}-rife.{file_extn}") 147 | elif out_file.suffix != file_extn: 148 | logger.warn("Output file extension does not match codec, changing extension") 149 | out_file = out_file.with_suffix(file_extn) 150 | 151 | # build RIFE command and get args 152 | # This doesn't need to be a Pydantic model tbh. It could just be a function/class. 153 | rife_opts = RifeNCNNOptions( 154 | model_path=rife_model_dir, 155 | input_path=frames_dir, 156 | output_path=rife_frames_dir, 157 | time_step=1 / in_fps, # TODO: make this configurable? 
158 | spatial_tta=spatial_tta, 159 | temporal_tta=temporal_tta, 160 | uhd=uhd, 161 | ) 162 | rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier) 163 | 164 | # actually run RIFE 165 | logger.info("Running RIFE, this may take a little while...") 166 | with subprocess.Popen( 167 | [rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.PIPE 168 | ) as proc: 169 | errs = [] 170 | for line in proc.stderr: 171 | line = line.decode("utf-8").strip() 172 | if line: 173 | logger.debug(line) 174 | stdout, _ = proc.communicate() 175 | if proc.returncode != 0: 176 | raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(errs)) 177 | 178 | # now it is ffmpeg time 179 | logger.info("Creating ffmpeg encoder...") 180 | encoder = FfmpegEncoder( 181 | frames_dir=rife_frames_dir, 182 | out_file=out_file, 183 | codec=codec, 184 | in_fps=min(out_fps, in_fps * frame_multiplier), 185 | out_fps=out_fps, 186 | lossless=lossless, 187 | ) 188 | logger.info("Encoding interpolated frames with ffmpeg...") 189 | result = encoder.encode() 190 | 191 | logger.debug(f"ffmpeg result: {result}") 192 | 193 | logger.info(f"Find the RIFE frames at: {rife_frames_dir.absolute().relative_to(Path.cwd())}") 194 | logger.info(f"Find the output file at: {out_file.absolute().relative_to(Path.cwd())}") 195 | logger.info("Done!") 196 | -------------------------------------------------------------------------------- /src/animatediff/pipelines/ti.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import torch 6 | from diffusers import DiffusionPipeline 7 | from safetensors.torch import load_file 8 | from torch import Tensor 9 | 10 | from animatediff import get_dir 11 | 12 | EMBED_DIR = get_dir("data").joinpath("embeddings") 13 | EMBED_DIR_SDXL = get_dir("data").joinpath("sdxl_embeddings") 14 | EMBED_EXTS = [".pt", ".pth", ".bin", ".safetensors"] 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def scan_text_embeddings(is_sdxl=False) -> list[Path]: 20 | embed_dir=EMBED_DIR_SDXL if is_sdxl else EMBED_DIR 21 | return [x for x in embed_dir.rglob("**/*") if x.is_file() and x.suffix.lower() in EMBED_EXTS] 22 | 23 | 24 | def get_text_embeddings(return_tensors: bool = True, is_sdxl:bool = False) -> dict[str, Union[Tensor, Path]]: 25 | embed_dir=EMBED_DIR_SDXL if is_sdxl else EMBED_DIR 26 | embeds = {} 27 | skipped = {} 28 | path: Path 29 | for path in scan_text_embeddings(is_sdxl): 30 | if path.stem not in embeds: 31 | # new token/name, add it 32 | logger.debug(f"Found embedding token {path.stem} at {path.relative_to(embed_dir)}") 33 | embeds[path.stem] = path 34 | else: 35 | # duplicate token/name, skip it 36 | skipped[path.stem] = path 37 | logger.debug(f"Duplicate embedding token {path.stem} at {path.relative_to(embed_dir)}") 38 | 39 | # warn the user if there are duplicates we skipped 40 | if skipped: 41 | logger.warn(f"Skipped {len(skipped)} embeddings with duplicate tokens!") 42 | logger.warn(f"Skipped paths: {[x.relative_to(embed_dir) for x in skipped.values()]}") 43 | logger.warn("Rename these files to avoid collisions!") 44 | 45 | # we can optionally return the tensors instead of the paths 46 | if return_tensors: 47 | # load the embeddings 48 | embeds = {k: load_embed_weights(v) for k, v in embeds.items()} 49 | # filter out the ones that failed to load 50 | loaded_embeds = {k: v for k, v in embeds.items() if v is not None} 51 | if 
52 | logger.warning(f"Failed to load {len(embeds) - len(loaded_embeds)} embeddings!")
53 | logger.warning(f"Skipped embeddings: {[x for x in embeds.keys() if x not in loaded_embeds]}")
54 | embeds = loaded_embeds
55 |
56 | # return a dict of {token: path | embedding}
57 | return embeds
58 |
59 |
60 | def load_embed_weights(path: Path, key: Optional[str] = None) -> Optional[Tensor]:
61 | """Load an embedding from a file.
62 | Accepts an optional key to load a specific embedding from a file with multiple embeddings, otherwise
63 | it will try to load the first one it finds.
64 | """
65 | if not path.exists() or not path.is_file():
66 | raise ValueError(f"Embedding path {path} does not exist or is not a file!")
67 | try:
68 | if path.suffix.lower() == ".safetensors":
69 | state_dict = load_file(path, device="cpu")
70 | elif path.suffix.lower() in EMBED_EXTS:
71 | state_dict = torch.load(path, weights_only=True, map_location="cpu")
72 | except Exception:
73 | logger.error(f"Failed to load embedding {path}", exc_info=True)
74 | return None
75 |
76 | embedding = None
77 | if len(state_dict) == 1:
78 | logger.debug(f"Found single key in {path.stem}, using it")
79 | embedding = next(iter(state_dict.values()))
80 | elif key is not None and key in state_dict:
81 | logger.debug(f"Using passed key {key} for {path.stem}")
82 | embedding = state_dict[key]
83 | elif "string_to_param" in state_dict:
84 | logger.debug(f"A1111 style embedding found for {path.stem}")
85 | embedding = next(iter(state_dict["string_to_param"].values()))
86 | else:
87 | # we couldn't find the embedding key, warn the user and just use the first key that's a Tensor
88 | logger.warning(f"Could not find embedding key in {path.stem}!")
89 | logger.warning("Taking a wild guess and using the first Tensor we find...")
90 | for key, value in state_dict.items():
91 | if torch.is_tensor(value):
92 | embedding = value
93 | logger.warning(f"Using key: {key}")
94 | break
95 |
96 | return embedding
97 |
98 |
99 | def load_text_embeddings(
100 | pipeline: DiffusionPipeline, text_embeds: Optional[dict[str, Path]] = None, is_sdxl: bool = False
101 | ) -> None:
102 | if text_embeds is None:
103 | text_embeds = get_text_embeddings(return_tensors=False, is_sdxl=is_sdxl)
104 | if len(text_embeds) < 1:
105 | logger.info("No TI embeddings found")
106 | return
107 |
108 | logger.info(f"Loading {len(text_embeds)} TI embeddings...")
109 | loaded, skipped, failed = [], [], []
110 |
111 | vocab = pipeline.tokenizer.get_vocab()  # get the tokenizer vocab so we can skip loaded embeddings
112 | for token, emb_path in text_embeds.items():
113 | try:
114 | if token not in vocab:
115 | if is_sdxl:
116 | embed = load_embed_weights(emb_path, "clip_g").to(pipeline.text_encoder_2.device)
117 | pipeline.load_textual_inversion(embed, token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
118 | embed = load_embed_weights(emb_path, "clip_l").to(pipeline.text_encoder.device)
119 | pipeline.load_textual_inversion(embed, token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
120 | else:
121 | embed = load_embed_weights(emb_path).to(pipeline.text_encoder.device)
122 | pipeline.load_textual_inversion({token: embed})
123 | logger.debug(f"Loaded embedding '{token}'")
124 | loaded.append(token)
125 | else:
126 | logger.debug(f"Skipping embedding '{token}' (already loaded)")
127 | skipped.append(token)
128 | except Exception:
129 | logger.error(f"Failed to load TI embedding: {token}", exc_info=True)
130 | failed.append(token)
131 |
132 | # Print a summary of what we loaded
133 | logger.info(f"Loaded {len(loaded)} embeddings, {len(skipped)} existing, {len(failed)} failed")
134 | logger.info(f"Available embeddings: {', '.join(loaded + skipped)}")
135 | if len(failed) > 0:
136 | # only print failed if there were failures
137 | logger.warning(f"Failed to load embeddings: {', '.join(failed)}")
138 |
--------------------------------------------------------------------------------
/src/animatediff/rife/ffmpeg.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from pathlib import Path
3 | from re import split
4 | from typing import Annotated, Optional, Union
5 |
6 | import ffmpeg
7 | from ffmpeg.nodes import FilterNode, InputNode
8 | from torch import Value
9 |
10 |
11 | class VideoCodec(str, Enum):
12 | gif = "gif"
13 | vp9 = "vp9"
14 | webm = "webm"
15 | webp = "webp"
16 | h264 = "h264"
17 | hevc = "hevc"
18 |
19 |
20 | def codec_extn(codec: VideoCodec):
21 | match codec:
22 | case VideoCodec.gif:
23 | return "gif"
24 | case VideoCodec.vp9:
25 | return "webm"
26 | case VideoCodec.webm:
27 | return "webm"
28 | case VideoCodec.webp:
29 | return "webp"
30 | case VideoCodec.h264:
31 | return "mp4"
32 | case VideoCodec.hevc:
33 | return "mp4"
34 | case _:
35 | raise ValueError(f"Unknown codec {codec}")
36 |
37 |
38 | def clamp_gif_fps(fps: int):
39 | """Clamp FPS to a value that is supported by GIFs.
40 |
41 | GIF frame duration is measured in 1/100ths of a second, so we need to clamp the
42 | FPS to a value that divides evenly into 100.
43 | """
44 | # the sky is not the limit, sadly...
45 | if fps > 100:
46 | return 100
47 |
48 | # if 100/fps is an integer, we're good
49 | if 100 % fps == 0:
50 | return fps
51 |
52 | # but of course, it was never going to be that easy.
53 | match fps:
54 | case x if x > 50:
55 | # 50 is the highest FPS (below 100) that divides evenly into 100.
56 | # people will ask for 60. they will get 50, and they will like it.
57 | return 50 58 | case x if x >= 30: 59 | return 33 60 | case x if x >= 24: 61 | return 25 62 | case x if x >= 20: 63 | return 20 64 | case x if x >= 15: 65 | # ffmpeg will pad a few frames to make this work 66 | return 16 67 | case x if x >= 12: 68 | return 12 69 | case x if x >= 10: 70 | # idk why anyone would request 11fps, but they're getting 10 71 | return 10 72 | case x if x >= 6: 73 | # also invalid but ffmpeg will pad it 74 | return 6 75 | case 4: 76 | return 4 # FINE, I GUESS 77 | case _: 78 | return 1 # I don't know why you would want this, but here you go 79 | 80 | 81 | class FfmpegEncoder: 82 | def __init__( 83 | self, 84 | frames_dir: Path, 85 | out_file: Path, 86 | codec: VideoCodec, 87 | in_fps: int = 60, 88 | out_fps: int = 60, 89 | lossless: bool = False, 90 | param={}, 91 | ): 92 | self.frames_dir = frames_dir 93 | self.out_file = out_file 94 | self.codec = codec 95 | self.in_fps = in_fps 96 | self.out_fps = out_fps 97 | self.lossless = lossless 98 | self.param = param 99 | 100 | self.input: Optional[InputNode] = None 101 | 102 | def encode(self) -> tuple: 103 | self.input: InputNode = ffmpeg.input( 104 | str(self.frames_dir.resolve().joinpath("%08d.png")), framerate=self.in_fps 105 | ).filter("fps", fps=self.in_fps) 106 | match self.codec: 107 | case VideoCodec.gif: 108 | return self._encode_gif() 109 | case VideoCodec.webm: 110 | return self._encode_webm() 111 | case VideoCodec.webp: 112 | return self._encode_webp() 113 | case VideoCodec.h264: 114 | return self._encode_h264() 115 | case VideoCodec.hevc: 116 | return self._encode_hevc() 117 | case _: 118 | raise ValueError(f"Unknown codec {self.codec}") 119 | 120 | @property 121 | def _out_file(self) -> Path: 122 | return str(self.out_file.resolve()) 123 | 124 | @staticmethod 125 | def _interpolate(stream, out_fps: int) -> FilterNode: 126 | return stream.filter( 127 | "minterpolate", fps=out_fps, mi_mode="mci", mc_mode="aobmc", me_mode="bidir", vsbmc=1 128 | ) 129 | 130 | def _encode_gif(self) -> tuple: 131 | stream: FilterNode = self.input 132 | 133 | # Output FPS must be divisible by 100 for GIFs, so we clamp it 134 | out_fps = clamp_gif_fps(self.out_fps) 135 | if self.in_fps != out_fps: 136 | stream = self._interpolate(stream, out_fps) 137 | 138 | # split into two streams for palettegen and paletteuse 139 | split_stream = stream.split() 140 | 141 | # generate the palette, then use it to encode the GIF 142 | palette = split_stream[0].filter("palettegen") 143 | stream = ffmpeg.filter([split_stream[1], palette], "paletteuse").output( 144 | self._out_file, vcodec="gif", loop=0 145 | ) 146 | return stream.run() 147 | 148 | def _encode_webm(self) -> tuple: 149 | stream: FilterNode = self.input 150 | if self.in_fps != self.out_fps: 151 | stream = self._interpolate(stream, self.out_fps) 152 | param = { 153 | "pix_fmt":"yuv420p", 154 | "vcodec":"libvpx-vp9", 155 | "video_bitrate":0, 156 | "crf":24, 157 | } 158 | param.update(**self.param) 159 | stream = stream.output( 160 | self._out_file, **param 161 | ) 162 | return stream.run() 163 | 164 | def _encode_webp(self) -> tuple: 165 | stream: FilterNode = self.input 166 | if self.in_fps != self.out_fps: 167 | stream = self._interpolate(stream, self.out_fps) 168 | 169 | if self.lossless: 170 | param = { 171 | "pix_fmt":"bgra", 172 | "vcodec":"libwebp_anim", 173 | "lossless":1, 174 | "compression_level":5, 175 | "qscale":75, 176 | "loop":0, 177 | } 178 | param.update(**self.param) 179 | stream = stream.output( 180 | self._out_file, 181 | **param 182 | ) 183 | else: 184 | param = { 
185 | "pix_fmt":"yuv420p",
186 | "vcodec":"libwebp_anim",
187 | "lossless":0,
188 | "compression_level":5,
189 | "qscale":90,
190 | "loop":0,
191 | }
192 | param.update(**self.param)
193 | stream = stream.output(
194 | self._out_file,
195 | **param
196 | )
197 | return stream.run()
198 |
199 | def _encode_h264(self) -> tuple:
200 | stream: FilterNode = self.input
201 | if self.in_fps != self.out_fps:
202 | stream = self._interpolate(stream, self.out_fps)
203 |
204 | param = {
205 | "pix_fmt":"yuv420p",
206 | "vcodec":"libx264",
207 | "crf":21,
208 | "tune":"animation",
209 | }
210 | param.update(**self.param)
211 |
212 | stream = stream.output(
213 | self._out_file, **param
214 | )
215 | return stream.run()
216 |
217 | def _encode_hevc(self) -> tuple:
218 | stream: FilterNode = self.input
219 | if self.in_fps != self.out_fps:
220 | stream = self._interpolate(stream, self.out_fps)
221 |
222 | param = {
223 | "pix_fmt":"yuv420p",
224 | "vcodec":"libx265",
225 | "crf":21,
226 | "tune":"animation",
227 | }
228 | param.update(**self.param)
229 |
230 | stream = stream.output(self._out_file, **param)
231 | return stream.run()
232 |
--------------------------------------------------------------------------------
/config/prompts/prompt_travel_multi_controlnet.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sample",
3 | "path": "share/Stable-diffusion/mistoonAnime_v20.safetensors",
4 | "motion_module": "models/motion-module/mm_sd_v14.ckpt",
5 | "compile": false,
6 | "seed": [
7 | 341774366206100
8 | ],
9 | "scheduler": "k_dpmpp_sde",
10 | "steps": 20,
11 | "guidance_scale": 10,
12 | "clip_skip": 2,
13 | "head_prompt": "masterpiece, best quality, a beautiful and detailed portrait of muffet, monster girl,((purple body:1.3)),humanoid, arachnid, anthro,((fangs)),pigtails,hair bows,5 eyes,spider girl,6 arms,solo",
14 | "prompt_map": {
15 | "0": "smile standing,((spider webs:1.0))",
16 | "32": "(((walking))),((spider webs:1.0))",
17 | "64": "(((running))),((spider webs:2.0)),wide angle lens, fish eye effect",
18 | "96": "(((sitting))),((spider webs:1.0))"
19 | },
20 | "tail_prompt": "clothed, open mouth, awesome and detailed background, holding teapot, holding teacup, 6 hands,detailed hands,storefront that sells pastries and tea,bloomers,(red and black clothing),inside,pouring into teacup,muffetwear",
21 | "n_prompt": [
22 | "(worst quality, low quality:1.4),nudity,simple background,border,mouth closed,text, patreon,bed,bedroom,white background,((monochrome)),sketch,(pink body:1.4),7 arms,8 arms,4 arms"
23 | ],
24 | "lora_map": {
25 | "share/Lora/muffet_v2.safetensors" : 1.0,
26 | "share/Lora/add_detail.safetensors" : 1.0
27 | },
28 | "ip_adapter_map": {
29 | "enable": true,
30 | "input_image_dir": "ip_adapter_image/test",
31 | "save_input_image": true,
32 | "resized_to_square": false,
33 | "scale": 0.5,
34 | "is_plus_face": true,
35 | "is_plus": true
36 | },
37 | "controlnet_map": {
38 | "input_image_dir" : "controlnet_image/test",
39 | "max_samples_on_vram": 200,
40 | "max_models_on_vram" : 3,
41 | "save_detectmap": true,
42 | "preprocess_on_gpu": true,
43 | "is_loop": true,
44 |
45 | "controlnet_tile":{
46 | "enable": true,
47 | "use_preprocessor":true,
48 | "preprocessor":{
49 | "type" : "none",
50 | "param":{
51 | }
52 | },
53 | "guess_mode":false,
54 | "controlnet_conditioning_scale": 1.0,
55 | "control_guidance_start": 0.0,
56 | "control_guidance_end": 1.0,
57 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1]
58 | },
59 | "controlnet_ip2p":{
60 |
"enable": true, 61 | "use_preprocessor":true, 62 | "guess_mode":false, 63 | "controlnet_conditioning_scale": 1.0, 64 | "control_guidance_start": 0.0, 65 | "control_guidance_end": 1.0, 66 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 67 | }, 68 | "controlnet_lineart_anime":{ 69 | "enable": true, 70 | "use_preprocessor":true, 71 | "guess_mode":false, 72 | "controlnet_conditioning_scale": 1.0, 73 | "control_guidance_start": 0.0, 74 | "control_guidance_end": 1.0, 75 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 76 | }, 77 | "controlnet_openpose":{ 78 | "enable": true, 79 | "use_preprocessor":true, 80 | "guess_mode":false, 81 | "controlnet_conditioning_scale": 1.0, 82 | "control_guidance_start": 0.0, 83 | "control_guidance_end": 1.0, 84 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 85 | }, 86 | "controlnet_softedge":{ 87 | "enable": true, 88 | "use_preprocessor":true, 89 | "preprocessor":{ 90 | "type" : "softedge_pidsafe", 91 | "param":{ 92 | } 93 | }, 94 | "guess_mode":false, 95 | "controlnet_conditioning_scale": 1.0, 96 | "control_guidance_start": 0.0, 97 | "control_guidance_end": 1.0, 98 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 99 | }, 100 | "controlnet_shuffle": { 101 | "enable": true, 102 | "use_preprocessor":true, 103 | "guess_mode":false, 104 | "controlnet_conditioning_scale": 1.0, 105 | "control_guidance_start": 0.0, 106 | "control_guidance_end": 1.0, 107 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 108 | }, 109 | "controlnet_depth": { 110 | "enable": true, 111 | "use_preprocessor":true, 112 | "guess_mode":false, 113 | "controlnet_conditioning_scale": 1.0, 114 | "control_guidance_start": 0.0, 115 | "control_guidance_end": 1.0, 116 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 117 | }, 118 | "controlnet_canny": { 119 | "enable": true, 120 | "use_preprocessor":true, 121 | "guess_mode":false, 122 | "controlnet_conditioning_scale": 1.0, 123 | "control_guidance_start": 0.0, 124 | "control_guidance_end": 1.0, 125 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 126 | }, 127 | "controlnet_inpaint": { 128 | "enable": true, 129 | "use_preprocessor":true, 130 | "guess_mode":false, 131 | "controlnet_conditioning_scale": 1.0, 132 | "control_guidance_start": 0.0, 133 | "control_guidance_end": 1.0, 134 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 135 | }, 136 | "controlnet_lineart": { 137 | "enable": true, 138 | "use_preprocessor":true, 139 | "guess_mode":false, 140 | "controlnet_conditioning_scale": 1.0, 141 | "control_guidance_start": 0.0, 142 | "control_guidance_end": 1.0, 143 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 144 | }, 145 | "controlnet_mlsd": { 146 | "enable": true, 147 | "use_preprocessor":true, 148 | "guess_mode":false, 149 | "controlnet_conditioning_scale": 1.0, 150 | "control_guidance_start": 0.0, 151 | "control_guidance_end": 1.0, 152 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 153 | }, 154 | "controlnet_normalbae": { 155 | "enable": true, 156 | "use_preprocessor":true, 157 | "guess_mode":false, 158 | "controlnet_conditioning_scale": 1.0, 159 | "control_guidance_start": 0.0, 160 | "control_guidance_end": 1.0, 161 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 162 | }, 163 | "controlnet_scribble": { 164 | "enable": true, 165 | "use_preprocessor":true, 166 | "guess_mode":false, 167 | "controlnet_conditioning_scale": 1.0, 168 | "control_guidance_start": 0.0, 169 | "control_guidance_end": 1.0, 170 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 171 | }, 172 | "controlnet_seg": { 173 | "enable": true, 174 | "use_preprocessor":true, 175 | "guess_mode":false, 176 | 
"controlnet_conditioning_scale": 1.0, 177 | "control_guidance_start": 0.0, 178 | "control_guidance_end": 1.0, 179 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 180 | }, 181 | "controlnet_ref": { 182 | "enable": false, 183 | "ref_image": "ref_image/ref_sample.png", 184 | "attention_auto_machine_weight": 0.3, 185 | "gn_auto_machine_weight": 0.3, 186 | "style_fidelity": 0.5, 187 | "reference_attn": true, 188 | "reference_adain": false, 189 | "scale_pattern":[1.0] 190 | } 191 | }, 192 | "upscale_config": { 193 | "scheduler": "k_dpmpp_sde", 194 | "steps": 20, 195 | "strength": 0.5, 196 | "guidance_scale": 10, 197 | "controlnet_tile": { 198 | "enable": true, 199 | "controlnet_conditioning_scale": 1.0, 200 | "guess_mode": false, 201 | "control_guidance_start": 0.0, 202 | "control_guidance_end": 1.0 203 | }, 204 | "controlnet_line_anime": { 205 | "enable": false, 206 | "controlnet_conditioning_scale": 1.0, 207 | "guess_mode": false, 208 | "control_guidance_start": 0.0, 209 | "control_guidance_end": 1.0 210 | }, 211 | "controlnet_ip2p": { 212 | "enable": true, 213 | "controlnet_conditioning_scale": 0.5, 214 | "guess_mode": false, 215 | "control_guidance_start": 0.0, 216 | "control_guidance_end": 1.0 217 | }, 218 | "controlnet_ref": { 219 | "enable": false, 220 | "use_frame_as_ref_image": false, 221 | "use_1st_frame_as_ref_image": true, 222 | "ref_image": "ref_image/path_to_your_ref_img.jpg", 223 | "attention_auto_machine_weight": 1.0, 224 | "gn_auto_machine_weight": 1.0, 225 | "style_fidelity": 0.25, 226 | "reference_attn": true, 227 | "reference_adain": false 228 | } 229 | }, 230 | "output":{ 231 | "preview_steps": [10], 232 | "format" : "gif", 233 | "fps" : 8, 234 | "encode_param":{ 235 | "crf": 10 236 | } 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /src/animatediff/pipelines/lora.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from safetensors.torch import load_file 4 | 5 | from animatediff import get_dir 6 | from animatediff.utils.lora_diffusers import (LoRANetwork, 7 | create_network_from_weights) 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | data_dir = get_dir("data") 12 | 13 | 14 | def merge_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True): 15 | 16 | def dump(loaded): 17 | for a in loaded: 18 | logger.info(f"{a} {loaded[a].shape}") 19 | 20 | sd = load_file(lora_path) 21 | 22 | if False: 23 | dump(sd) 24 | 25 | print(f"create LoRA network") 26 | lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff) 27 | print(f"load LoRA network weights") 28 | lora_network.load_state_dict(sd, False) 29 | lora_network.merge_to(alpha) 30 | 31 | def load_lora_map(pipe, lora_map_config, video_length, is_sdxl=False): 32 | new_map = {} 33 | for item in lora_map_config: 34 | lora_path = data_dir.joinpath(item) 35 | if type(lora_map_config[item]) in (float,int): 36 | te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder 37 | merge_safetensors_lora(te_en, pipe.unet, lora_path, lora_map_config[item], not is_sdxl) 38 | else: 39 | new_map[lora_path] = lora_map_config[item] 40 | 41 | lora_map = LoraMap(pipe, new_map, video_length, is_sdxl) 42 | pipe.lora_map = lora_map if lora_map.is_valid else None 43 | 44 | def load_lcm_lora(pipe, lcm_map, is_sdxl=False, is_merge=False): 45 | if is_sdxl: 46 | lora_path = 
data_dir.joinpath("models/lcm_lora/sdxl/pytorch_lora_weights.safetensors") 47 | else: 48 | lora_path = data_dir.joinpath("models/lcm_lora/sd15/pytorch_lora_weights.safetensors") 49 | logger.info(f"{lora_path=}") 50 | 51 | if is_merge: 52 | te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder 53 | merge_safetensors_lora(te_en, pipe.unet, lora_path, 1.0, not is_sdxl) 54 | pipe.lcm = None 55 | return 56 | 57 | lcm = LcmLora(pipe, is_sdxl, lora_path, lcm_map) 58 | pipe.lcm = lcm if lcm.is_valid else None 59 | 60 | class LcmLora: 61 | def __init__( 62 | self, 63 | pipe, 64 | is_sdxl, 65 | lora_path, 66 | lcm_map 67 | ): 68 | self.is_valid = False 69 | 70 | sd = load_file(lora_path) 71 | if not sd: 72 | return 73 | 74 | te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder 75 | lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=1.0, is_animatediff=not is_sdxl) 76 | lora_network.load_state_dict(sd, False) 77 | lora_network.apply_to(1.0) 78 | self.network = lora_network 79 | 80 | self.is_valid = True 81 | 82 | self.start_scale = lcm_map["start_scale"] 83 | self.end_scale = lcm_map["end_scale"] 84 | self.gradient_start = lcm_map["gradient_start"] 85 | self.gradient_end = lcm_map["gradient_end"] 86 | 87 | 88 | def to( 89 | self, 90 | device, 91 | dtype, 92 | ): 93 | self.network.to(device=device, dtype=dtype) 94 | 95 | def apply( 96 | self, 97 | step, 98 | total_steps, 99 | ): 100 | step += 1 101 | progress = step / total_steps 102 | 103 | if progress < self.gradient_start: 104 | scale = self.start_scale 105 | elif progress > self.gradient_end: 106 | scale = self.end_scale 107 | else: 108 | if (self.gradient_end - self.gradient_start) < 1e-4: 109 | progress = 0 110 | else: 111 | progress = (progress - self.gradient_start) / (self.gradient_end - self.gradient_start) 112 | scale = (self.end_scale - self.start_scale) * progress 113 | scale += self.start_scale 114 | 115 | self.network.active( scale ) 116 | 117 | def unapply( 118 | self, 119 | ): 120 | self.network.deactive( ) 121 | 122 | 123 | 124 | class LoraMap: 125 | def __init__( 126 | self, 127 | pipe, 128 | lora_map, 129 | video_length, 130 | is_sdxl, 131 | ): 132 | self.networks = [] 133 | 134 | def create_schedule(scales, length): 135 | scales = { int(i):scales[i] for i in scales } 136 | keys = sorted(scales.keys()) 137 | 138 | if len(keys) == 1: 139 | return { i:scales[keys[0]] for i in range(length) } 140 | keys = keys + [keys[0]] 141 | 142 | schedule={} 143 | 144 | def calc(rate,start_v,end_v): 145 | return start_v + (rate * rate)*(end_v - start_v) 146 | 147 | for key_prev,key_next in zip(keys[:-1],keys[1:]): 148 | v1 = scales[key_prev] 149 | v2 = scales[key_next] 150 | if key_prev > key_next: 151 | key_next += length 152 | for i in range(key_prev,key_next): 153 | dist = i-key_prev 154 | if i >= length: 155 | i -= length 156 | schedule[i] = calc( dist/(key_next-key_prev), v1, v2 ) 157 | return schedule 158 | 159 | for lora_path in lora_map: 160 | sd = load_file(lora_path) 161 | if not sd: 162 | continue 163 | te_en = [pipe.text_encoder, pipe.text_encoder_2] if is_sdxl else pipe.text_encoder 164 | lora_network: LoRANetwork = create_network_from_weights(te_en, pipe.unet, sd, multiplier=0.75, is_animatediff=not is_sdxl) 165 | lora_network.load_state_dict(sd, False) 166 | lora_network.apply_to(0.75) 167 | 168 | self.networks.append( 169 | { 170 | "network":lora_network, 171 | "region":lora_map[lora_path]["region"], 172 | "schedule": 
create_schedule(lora_map[lora_path]["scale"], video_length ) 173 | } 174 | ) 175 | 176 | def region_convert(i): 177 | if i == "background": 178 | return 0 179 | else: 180 | return int(i) + 1 181 | 182 | for net in self.networks: 183 | net["region"] = [ region_convert(i) for i in net["region"] ] 184 | 185 | # for n in self.networks: 186 | # logger.info(f"{n['region']=}") 187 | # logger.info(f"{n['schedule']=}") 188 | 189 | if self.networks: 190 | self.is_valid = True 191 | else: 192 | self.is_valid = False 193 | 194 | def to( 195 | self, 196 | device, 197 | dtype, 198 | ): 199 | for net in self.networks: 200 | net["network"].to(device=device, dtype=dtype) 201 | 202 | def apply( 203 | self, 204 | cond_index, 205 | cond_nums, 206 | frame_no, 207 | ): 208 | ''' 209 | neg 0 (bg) 210 | neg 1 211 | neg 2 212 | pos 0 (bg) 213 | pos 1 214 | pos 2 215 | ''' 216 | 217 | region_index = cond_index if cond_index < cond_nums//2 else cond_index - cond_nums//2 218 | # logger.info(f"{cond_index=}") 219 | # logger.info(f"{cond_nums=}") 220 | # logger.info(f"{region_index=}") 221 | 222 | 223 | for i,net in enumerate(self.networks): 224 | if region_index in net["region"]: 225 | scale = net["schedule"][frame_no] 226 | if scale > 0: 227 | net["network"].active( scale ) 228 | # logger.info(f"{i=} active {scale=}") 229 | else: 230 | net["network"].deactive( ) 231 | # logger.info(f"{i=} DEactive") 232 | 233 | else: 234 | net["network"].deactive( ) 235 | # logger.info(f"{i=} DEactive") 236 | 237 | def unapply( 238 | self, 239 | ): 240 | 241 | for net in self.networks: 242 | net["network"].deactive( ) 243 | 244 | -------------------------------------------------------------------------------- /src/animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 8 | from einops import rearrange 9 | from torch import Tensor, nn 10 | 11 | 12 | #class InflatedConv3d(nn.Conv2d): 13 | class InflatedConv3d(LoRACompatibleConv): 14 | def forward(self, x: Tensor) -> Tensor: 15 | frames = x.shape[2] 16 | 17 | x = rearrange(x, "b c f h w -> (b f) c h w") 18 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 19 | x = rearrange(x, "(b f) c h w -> b c f h w", f=frames) 20 | return x 21 | 22 | class InflatedGroupNorm(nn.GroupNorm): 23 | def forward(self, x): 24 | video_length = x.shape[2] 25 | 26 | x = rearrange(x, "b c f h w -> (b f) c h w") 27 | x = super().forward(x) 28 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 29 | 30 | return x 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__( 34 | self, 35 | channels: int, 36 | use_conv: bool = False, 37 | use_conv_transpose: bool = False, 38 | out_channels: Optional[int] = None, 39 | name="conv", 40 | ): 41 | super().__init__() 42 | self.channels = channels 43 | self.out_channels = out_channels or channels 44 | self.use_conv = use_conv 45 | self.use_conv_transpose = use_conv_transpose 46 | self.name = name 47 | 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states: Tensor, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if 
self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 72 | else: 73 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 74 | 75 | # If the input is bfloat16, we cast back to bfloat16 76 | if dtype == torch.bfloat16: 77 | hidden_states = hidden_states.to(dtype) 78 | 79 | hidden_states = self.conv(hidden_states) 80 | 81 | return hidden_states 82 | 83 | 84 | class Downsample3D(nn.Module): 85 | def __init__( 86 | self, 87 | channels: int, 88 | use_conv: bool = False, 89 | out_channels: Optional[int] = None, 90 | padding: int = 1, 91 | name="conv", 92 | ): 93 | super().__init__() 94 | self.channels = channels 95 | self.out_channels = out_channels or channels 96 | self.use_conv = use_conv 97 | self.padding = padding 98 | stride = 2 99 | self.name = name 100 | 101 | if use_conv: 102 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 103 | else: 104 | raise NotImplementedError 105 | 106 | def forward(self, hidden_states): 107 | assert hidden_states.shape[1] == self.channels 108 | if self.use_conv and self.padding == 0: 109 | raise NotImplementedError 110 | 111 | assert hidden_states.shape[1] == self.channels 112 | hidden_states = self.conv(hidden_states) 113 | 114 | return hidden_states 115 | 116 | 117 | class ResnetBlock3D(nn.Module): 118 | def __init__( 119 | self, 120 | *, 121 | in_channels, 122 | out_channels=None, 123 | conv_shortcut=False, 124 | dropout=0.0, 125 | temb_channels=512, 126 | groups=32, 127 | groups_out=None, 128 | pre_norm=True, 129 | eps=1e-6, 130 | non_linearity="swish", 131 | time_embedding_norm="default", 132 | output_scale_factor=1.0, 133 | use_in_shortcut=None, 134 | use_inflated_groupnorm=None, 135 | ): 136 | super().__init__() 137 | self.pre_norm = pre_norm 138 | self.pre_norm = True 139 | self.in_channels = in_channels 140 | out_channels = in_channels if out_channels is None else out_channels 141 | self.out_channels = out_channels 142 | self.use_conv_shortcut = conv_shortcut 143 | self.time_embedding_norm = time_embedding_norm 144 | self.output_scale_factor = output_scale_factor 145 | 146 | if groups_out is None: 147 | groups_out = groups 148 | 149 | assert use_inflated_groupnorm != None 150 | if use_inflated_groupnorm: 151 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 152 | else: 153 | self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 154 | 155 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 156 | 157 | if temb_channels is not None: 158 | if self.time_embedding_norm == "default": 159 | time_emb_proj_out_channels = out_channels 160 | elif self.time_embedding_norm == "scale_shift": 161 | time_emb_proj_out_channels = out_channels * 2 162 | else: 163 | raise ValueError(f"unknown time_embedding_norm : 
{self.time_embedding_norm} ") 164 | 165 | # self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels) 166 | self.time_emb_proj = LoRACompatibleLinear(temb_channels, time_emb_proj_out_channels) 167 | else: 168 | self.time_emb_proj = None 169 | 170 | if use_inflated_groupnorm: 171 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 172 | else: 173 | self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 174 | 175 | self.dropout = nn.Dropout(dropout) 176 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 177 | 178 | if non_linearity == "swish": 179 | self.nonlinearity = lambda x: F.silu(x) 180 | elif non_linearity == "mish": 181 | self.nonlinearity = Mish() 182 | elif non_linearity == "silu": 183 | self.nonlinearity = nn.SiLU() 184 | 185 | self.use_in_shortcut = ( 186 | self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 187 | ) 188 | 189 | self.conv_shortcut = None 190 | if self.use_in_shortcut: 191 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 192 | 193 | def forward(self, input_tensor, temb): 194 | hidden_states = input_tensor 195 | 196 | hidden_states = self.norm1(hidden_states) 197 | hidden_states = self.nonlinearity(hidden_states) 198 | 199 | hidden_states = self.conv1(hidden_states) 200 | 201 | if temb is not None: 202 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 203 | 204 | if temb is not None and self.time_embedding_norm == "default": 205 | hidden_states = hidden_states + temb 206 | 207 | hidden_states = self.norm2(hidden_states) 208 | 209 | if temb is not None and self.time_embedding_norm == "scale_shift": 210 | scale, shift = torch.chunk(temb, 2, dim=1) 211 | hidden_states = hidden_states * (1 + scale) + shift 212 | 213 | hidden_states = self.nonlinearity(hidden_states) 214 | 215 | hidden_states = self.dropout(hidden_states) 216 | hidden_states = self.conv2(hidden_states) 217 | 218 | if self.conv_shortcut is not None: 219 | input_tensor = self.conv_shortcut(input_tensor) 220 | 221 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 222 | 223 | return output_tensor 224 | 225 | 226 | class Mish(nn.Module): 227 | def forward(self, hidden_states): 228 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 229 | -------------------------------------------------------------------------------- /config/prompts/img2img_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sample", 3 | "path": "share/Stable-diffusion/mistoonAnime_v20.safetensors", 4 | "motion_module": "models/motion-module/mm_sd_v15_v2.ckpt", 5 | "compile": false, 6 | "seed": [ 7 | 12345 8 | ], 9 | "scheduler": "k_dpmpp_sde", 10 | "steps": 20, 11 | "guidance_scale": 10, 12 | "unet_batch_size": 1, 13 | "clip_skip": 2, 14 | "prompt_fixed_ratio": 0.5, 15 | "head_prompt": "(style of studio ghibli:1.2), (masterpiece, best quality)", 16 | "prompt_map": { 17 | "0": "forest, water, river, outdoors," 18 | }, 19 | "tail_prompt": "", 20 | "n_prompt": [ 21 | "(worst quality:2), (bad quality:2), (normal quality:2), lowers, bad anatomy, bad hands, (multiple views)," 22 | ], 23 | "lora_map": { 24 | "share/models/Lora/Ghibli_v6.safetensors": 1.0 25 | }, 26 | "motion_lora_map": { 27 | }, 28 | "ip_adapter_map": { 29 | "enable": false, 30 | "input_image_dir": "", 31 
| "prompt_fixed_ratio": 0.5, 32 | "save_input_image": true, 33 | "resized_to_square": false, 34 | "scale": 0.5, 35 | "is_plus_face": false, 36 | "is_plus": true, 37 | "is_light": false 38 | }, 39 | "img2img_map":{ 40 | "enable": true, 41 | "init_img_dir" : "init_imgs/sample0", 42 | "save_init_image": true, 43 | "denoising_strength" : 0.85 44 | }, 45 | "controlnet_map": { 46 | "input_image_dir" : "", 47 | "max_samples_on_vram": 0, 48 | "max_models_on_vram" : 1, 49 | "save_detectmap": true, 50 | "preprocess_on_gpu": true, 51 | "is_loop": true, 52 | 53 | "controlnet_tile":{ 54 | "enable": true, 55 | "use_preprocessor":true, 56 | "preprocessor":{ 57 | "type" : "none", 58 | "param":{ 59 | } 60 | }, 61 | "guess_mode":false, 62 | "controlnet_conditioning_scale": 1.0, 63 | "control_guidance_start": 0.0, 64 | "control_guidance_end": 1.0, 65 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 66 | }, 67 | "controlnet_ip2p":{ 68 | "enable": true, 69 | "use_preprocessor":true, 70 | "guess_mode":false, 71 | "controlnet_conditioning_scale": 1.0, 72 | "control_guidance_start": 0.0, 73 | "control_guidance_end": 1.0, 74 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 75 | }, 76 | "controlnet_lineart_anime":{ 77 | "enable": true, 78 | "use_preprocessor":true, 79 | "guess_mode":false, 80 | "controlnet_conditioning_scale": 1.0, 81 | "control_guidance_start": 0.0, 82 | "control_guidance_end": 1.0, 83 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 84 | }, 85 | "controlnet_openpose":{ 86 | "enable": true, 87 | "use_preprocessor":true, 88 | "guess_mode":false, 89 | "controlnet_conditioning_scale": 1.0, 90 | "control_guidance_start": 0.0, 91 | "control_guidance_end": 1.0, 92 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 93 | }, 94 | "controlnet_softedge":{ 95 | "enable": true, 96 | "use_preprocessor":true, 97 | "preprocessor":{ 98 | "type" : "softedge_pidsafe", 99 | "param":{ 100 | } 101 | }, 102 | "guess_mode":false, 103 | "controlnet_conditioning_scale": 1.0, 104 | "control_guidance_start": 0.0, 105 | "control_guidance_end": 1.0, 106 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 107 | }, 108 | "controlnet_shuffle": { 109 | "enable": true, 110 | "use_preprocessor":true, 111 | "guess_mode":false, 112 | "controlnet_conditioning_scale": 1.0, 113 | "control_guidance_start": 0.0, 114 | "control_guidance_end": 1.0, 115 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 116 | }, 117 | "controlnet_depth": { 118 | "enable": true, 119 | "use_preprocessor":true, 120 | "guess_mode":false, 121 | "controlnet_conditioning_scale": 1.0, 122 | "control_guidance_start": 0.0, 123 | "control_guidance_end": 1.0, 124 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 125 | }, 126 | "controlnet_canny": { 127 | "enable": true, 128 | "use_preprocessor":true, 129 | "guess_mode":false, 130 | "controlnet_conditioning_scale": 1.0, 131 | "control_guidance_start": 0.0, 132 | "control_guidance_end": 1.0, 133 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 134 | }, 135 | "controlnet_inpaint": { 136 | "enable": true, 137 | "use_preprocessor":true, 138 | "guess_mode":false, 139 | "controlnet_conditioning_scale": 1.0, 140 | "control_guidance_start": 0.0, 141 | "control_guidance_end": 1.0, 142 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 143 | }, 144 | "controlnet_lineart": { 145 | "enable": true, 146 | "use_preprocessor":true, 147 | "guess_mode":false, 148 | "controlnet_conditioning_scale": 1.0, 149 | "control_guidance_start": 0.0, 150 | "control_guidance_end": 1.0, 151 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 152 | }, 153 | "controlnet_mlsd": { 154 | "enable": true, 155 | 
"use_preprocessor":true, 156 | "guess_mode":false, 157 | "controlnet_conditioning_scale": 1.0, 158 | "control_guidance_start": 0.0, 159 | "control_guidance_end": 1.0, 160 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 161 | }, 162 | "controlnet_normalbae": { 163 | "enable": true, 164 | "use_preprocessor":true, 165 | "guess_mode":false, 166 | "controlnet_conditioning_scale": 1.0, 167 | "control_guidance_start": 0.0, 168 | "control_guidance_end": 1.0, 169 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 170 | }, 171 | "controlnet_scribble": { 172 | "enable": true, 173 | "use_preprocessor":true, 174 | "guess_mode":false, 175 | "controlnet_conditioning_scale": 1.0, 176 | "control_guidance_start": 0.0, 177 | "control_guidance_end": 1.0, 178 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 179 | }, 180 | "controlnet_seg": { 181 | "enable": true, 182 | "use_preprocessor":true, 183 | "guess_mode":false, 184 | "controlnet_conditioning_scale": 1.0, 185 | "control_guidance_start": 0.0, 186 | "control_guidance_end": 1.0, 187 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 188 | }, 189 | "qr_code_monster_v1": { 190 | "enable": true, 191 | "use_preprocessor":true, 192 | "guess_mode":false, 193 | "controlnet_conditioning_scale": 1.0, 194 | "control_guidance_start": 0.0, 195 | "control_guidance_end": 1.0, 196 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 197 | }, 198 | "qr_code_monster_v2": { 199 | "enable": true, 200 | "use_preprocessor":true, 201 | "guess_mode":false, 202 | "controlnet_conditioning_scale": 1.0, 203 | "control_guidance_start": 0.0, 204 | "control_guidance_end": 1.0, 205 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 206 | }, 207 | "controlnet_mediapipe_face": { 208 | "enable": true, 209 | "use_preprocessor":true, 210 | "guess_mode":false, 211 | "controlnet_conditioning_scale": 1.0, 212 | "control_guidance_start": 0.0, 213 | "control_guidance_end": 1.0, 214 | "control_scale_list":[0.5,0.4,0.3,0.2,0.1] 215 | }, 216 | "controlnet_ref": { 217 | "enable": false, 218 | "ref_image": "ref_image/ref_sample.png", 219 | "attention_auto_machine_weight": 0.3, 220 | "gn_auto_machine_weight": 0.3, 221 | "style_fidelity": 0.5, 222 | "reference_attn": true, 223 | "reference_adain": false, 224 | "scale_pattern":[1.0] 225 | } 226 | }, 227 | "upscale_config": { 228 | "scheduler": "k_dpmpp_sde", 229 | "steps": 20, 230 | "strength": 0.5, 231 | "guidance_scale": 10, 232 | "controlnet_tile": { 233 | "enable": true, 234 | "controlnet_conditioning_scale": 1.0, 235 | "guess_mode": false, 236 | "control_guidance_start": 0.0, 237 | "control_guidance_end": 1.0 238 | }, 239 | "controlnet_line_anime": { 240 | "enable": false, 241 | "controlnet_conditioning_scale": 1.0, 242 | "guess_mode": false, 243 | "control_guidance_start": 0.0, 244 | "control_guidance_end": 1.0 245 | }, 246 | "controlnet_ip2p": { 247 | "enable": false, 248 | "controlnet_conditioning_scale": 0.5, 249 | "guess_mode": false, 250 | "control_guidance_start": 0.0, 251 | "control_guidance_end": 1.0 252 | }, 253 | "controlnet_ref": { 254 | "enable": false, 255 | "use_frame_as_ref_image": false, 256 | "use_1st_frame_as_ref_image": false, 257 | "ref_image": "ref_image/path_to_your_ref_img.jpg", 258 | "attention_auto_machine_weight": 1.0, 259 | "gn_auto_machine_weight": 1.0, 260 | "style_fidelity": 0.25, 261 | "reference_attn": true, 262 | "reference_adain": false 263 | } 264 | }, 265 | "output":{ 266 | "format" : "mp4", 267 | "fps" : 8, 268 | "encode_param":{ 269 | "crf": 10 270 | } 271 | } 272 | } 273 | 
--------------------------------------------------------------------------------