├── .gitignore ├── .gitmodules ├── CHANGES ├── Dockerfile ├── LICENSE ├── README-zh.md ├── README.md ├── configs ├── model │ ├── T2I_all_model.py │ ├── ip_adapter.py │ ├── lcm_model.py │ ├── motion_model.py │ ├── negative_prompt.py │ └── referencenet.py └── tasks │ └── example.yaml ├── data ├── demo │ ├── cyber_girl.png │ └── video1.mp4 ├── images │ ├── Mona_Lisa.jpg │ ├── Portrait-of-Dr.-Gachet.jpg │ ├── Self-Portrait-with-Cropped-Hair.jpg │ ├── The-Laughing-Cavalier.jpg │ ├── boy_play_guitar.jpeg │ ├── boy_play_guitar2.jpeg │ ├── cyber_girl.png │ ├── duffy.png │ ├── dufu.jpeg │ ├── girl_play_guitar2.jpeg │ ├── girl_play_guitar4.jpeg │ ├── jinkesi2.jpeg │ ├── river.jpeg │ ├── seaside2.jpeg │ ├── seaside4.jpeg │ ├── seaside_girl.jpeg │ ├── spark_girl.png │ ├── waterfall4.jpeg │ └── yongen.jpeg ├── models │ ├── musev_structure.png │ └── parallel_denoise.png ├── result_video │ ├── Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.mp4 │ ├── Portrait-of-Dr.-Gachet.mp4 │ ├── Self-Portrait-with-Cropped-Hair.mp4 │ ├── The-Laughing-Cavalier.mp4 │ ├── boy_play_guitar.mp4 │ ├── boy_play_guitar2.mp4 │ ├── dufu.mp4 │ ├── girl_play_guitar2.mp4 │ ├── girl_play_guitar4.mp4 │ ├── jinkesi2.mp4 │ ├── river.mp4 │ ├── seaside2.mp4 │ ├── seaside4.mp4 │ ├── seaside_girl.mp4 │ ├── waterfall4.mp4 │ └── yongen.mp4 └── source_video │ ├── pose-for-Duffy-4.mp4 │ └── video1_girl_poseseq.mp4 ├── environment.yml ├── musev ├── __init__.py ├── auto_prompt │ ├── __init__.py │ ├── attributes │ │ ├── __init__.py │ │ ├── attr2template.py │ │ ├── attributes.py │ │ ├── human.py │ │ ├── render.py │ │ └── style.py │ ├── human.py │ ├── load_template.py │ └── util.py ├── data │ ├── __init__.py │ └── data_util.py ├── logging.conf ├── models │ ├── __init__.py │ ├── attention.py │ ├── attention_processor.py │ ├── controlnet.py │ ├── embeddings.py │ ├── facein_loader.py │ ├── ip_adapter_face_loader.py │ ├── ip_adapter_loader.py │ ├── referencenet.py │ ├── referencenet_loader.py │ ├── resnet.py │ ├── super_model.py │ ├── temporal_transformer.py │ ├── text_model.py │ ├── transformer_2d.py │ ├── unet_2d_blocks.py │ ├── unet_3d_blocks.py │ ├── unet_3d_condition.py │ └── unet_loader.py ├── pipelines │ ├── __init__.py │ ├── context.py │ ├── pipeline_controlnet.py │ └── pipeline_controlnet_predictor.py ├── schedulers │ ├── __init__.py │ ├── scheduling_ddim.py │ ├── scheduling_ddpm.py │ ├── scheduling_dpmsolver_multistep.py │ ├── scheduling_euler_ancestral_discrete.py │ ├── scheduling_euler_discrete.py │ └── scheduling_lcm.py └── utils │ ├── __init__.py │ ├── attention_util.py │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ ├── model_util.py │ ├── noise_util.py │ ├── register.py │ ├── tensor_util.py │ ├── text_emb_util.py │ ├── timesteps_util.py │ ├── util.py │ └── vae_util.py ├── requirements.txt ├── scripts ├── gradio │ ├── Dockerfile │ ├── app.py │ ├── app_docker_space.py │ ├── app_gradio_space.py │ ├── entrypoint.sh │ ├── gradio_text2video.py │ └── gradio_video2video.py └── inference │ ├── text2video.py │ └── video2video.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 
26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | *.swp 126 | .*.swp 127 | 128 | .DS_Store 129 | 130 | # project 131 | outputs/ 132 | results/ 133 | scripts/codetest/ 134 | # configs/train/video_creation_anchorxia_* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MMCM"] 2 | path = MMCM 3 | url = https://github.com/TMElyralab/MMCM.git 4 | [submodule "controlnet_aux"] 5 | path = controlnet_aux 6 | url = https://github.com/TMElyralab/controlnet_aux.git 7 | branch = tme 8 | [submodule "diffusers"] 9 | path = diffusers 10 | url = https://github.com/TMElyralab/diffusers.git 11 | branch = tme 12 | -------------------------------------------------------------------------------- /CHANGES: -------------------------------------------------------------------------------- 1 | Version 1.0.0 (2024.03.27) 2 | 3 | * init musev, supporting video generation with text and image 4 | * controlnet_aux: enriched the interface and functionality of dwpose. 5 | * diffusers: controlnet now supports latents, not just images. -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM anchorxia/musev:1.0.0 2 | 3 | # MAINTAINER info 4 | LABEL MAINTAINER="anchorxia" 5 | LABEL Email="anchorxia@tencent.com" 6 | LABEL Description="musev gpu runtime image, base docker is pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel" 7 | ARG DEBIAN_FRONTEND=noninteractive 8 | 9 | USER root 10 | 11 | SHELL ["/bin/bash", "--login", "-c"] 12 | 13 | RUN .
/opt/conda/etc/profile.d/conda.sh \ 14 | && echo "source activate musev" >> ~/.bashrc \ 15 | && conda activate musev \ 16 | && conda env list \ 17 | && pip --no-cache-dir install cuid gradio==4.12 spaces 18 | USER root 19 | -------------------------------------------------------------------------------- /configs/model/T2I_all_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | T2IDir = os.path.join( 5 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "t2i" 6 | ) 7 | 8 | MODEL_CFG = { 9 | "majicmixRealv6Fp16": { 10 | "sd": os.path.join(T2IDir, "sd1.5/majicmixRealv6Fp16"), 11 | }, 12 | "fantasticmix_v10": { 13 | "sd": os.path.join(T2IDir, "sd1.5/fantasticmix_v10"), 14 | }, 15 | } 16 | -------------------------------------------------------------------------------- /configs/model/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | IPAdapterModelDir = os.path.join( 4 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "IP-Adapter" 5 | ) 6 | 7 | 8 | MotionDir = os.path.join( 9 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion" 10 | ) 11 | 12 | 13 | MODEL_CFG = { 14 | "IPAdapter": { 15 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "models/image_encoder"), 16 | "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter_sd15.bin"), 17 | "ip_scale": 1.0, 18 | "clip_extra_context_tokens": 4, 19 | "clip_embeddings_dim": 1024, 20 | "desp": "", 21 | }, 22 | "IPAdapterPlus": { 23 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"), 24 | "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus_sd15.bin"), 25 | "ip_scale": 1.0, 26 | "clip_extra_context_tokens": 16, 27 | "clip_embeddings_dim": 1024, 28 | "desp": "", 29 | }, 30 | "IPAdapterPlus-face": { 31 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"), 32 | "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-plus-face_sd15.bin"), 33 | "ip_scale": 1.0, 34 | "clip_extra_context_tokens": 16, 35 | "clip_embeddings_dim": 1024, 36 | "desp": "", 37 | }, 38 | "IPAdapterFaceID": { 39 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"), 40 | "ip_ckpt": os.path.join(IPAdapterModelDir, "ip-adapter-faceid_sd15.bin"), 41 | "ip_scale": 1.0, 42 | "clip_extra_context_tokens": 4, 43 | "clip_embeddings_dim": 512, 44 | "desp": "", 45 | }, 46 | "musev_referencenet": { 47 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"), 48 | "ip_ckpt": os.path.join( 49 | MotionDir, "musev_referencenet/ip_adapter_image_proj.bin" 50 | ), 51 | "ip_scale": 1.0, 52 | "clip_extra_context_tokens": 4, 53 | "clip_embeddings_dim": 1024, 54 | "desp": "", 55 | }, 56 | "musev_referencenet_pose": { 57 | "ip_image_encoder": os.path.join(IPAdapterModelDir, "image_encoder"), 58 | "ip_ckpt": os.path.join( 59 | MotionDir, "musev_referencenet_pose/ip_adapter_image_proj.bin" 60 | ), 61 | "ip_scale": 1.0, 62 | "clip_extra_context_tokens": 4, 63 | "clip_embeddings_dim": 1024, 64 | "desp": "", 65 | }, 66 | } 67 | -------------------------------------------------------------------------------- /configs/model/lcm_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | LCMDir = os.path.join( 5 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "lcm" 6 | ) 7 | 8 | 9 | MODEL_CFG = { 10 | "lcm": { 11 | os.path.join(LCMDir, 
"lcm-lora-sdv1-5/pytorch_lora_weights.safetensors"): { 12 | "strength": 1.0, 13 | "lora_block_weight": "ALL", 14 | "strength_offset": 0, 15 | }, 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /configs/model/motion_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | MotionDIr = os.path.join( 5 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion" 6 | ) 7 | 8 | 9 | MODEL_CFG = { 10 | "musev": { 11 | "unet": os.path.join(MotionDIr, "musev"), 12 | "desp": "only train unet motion module, fix t2i", 13 | }, 14 | "musev_referencenet": { 15 | "unet": os.path.join(MotionDIr, "musev_referencenet"), 16 | "desp": "train referencenet, IPAdapter and unet motion module, fix t2i", 17 | }, 18 | "musev_referencenet_pose": { 19 | "unet": os.path.join(MotionDIr, "musev_referencenet_pose"), 20 | "desp": "train unet motion module and IPAdapter, fix t2i and referencenet", 21 | }, 22 | } 23 | -------------------------------------------------------------------------------- /configs/model/negative_prompt.py: -------------------------------------------------------------------------------- 1 | Negative_Prompt_CFG = { 2 | "Empty": { 3 | "base_model": "", 4 | "prompt": "", 5 | "refer": "", 6 | }, 7 | "V1": { 8 | "base_model": "", 9 | "prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, tail, watermarks", 10 | "refer": "", 11 | }, 12 | "V2": { 13 | "base_model": "", 14 | "prompt": "badhandv4, ng_deepnegative_v1_75t, (((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)", 15 | "refer": "Weiban", 16 | }, 17 | "V3": { 18 | "base_model": "", 19 | "prompt": "badhandv4, ng_deepnegative_v1_75t, bad quality", 20 | "refer": "", 21 | }, 22 | "V4": { 23 | "base_model": "", 24 | "prompt": "badhandv4,ng_deepnegative_v1_75t,EasyNegativeV2,bad_prompt_version2-neg,bad quality", 25 | "refer": "", 26 | }, 27 | "V5": { 28 | "base_model": "", 29 | "prompt": "(((multiple heads))), (((bad body))), (((two people))), ((extra arms)), ((deformed body)), (((sexy))), paintings,(((two heads))), ((big head)),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (((nsfw))), nipples, extra fingers, (extra legs), (long neck), mutated hands, (fused fingers), (too many fingers)", 30 | "refer": "Weiban", 31 | }, 32 | } 33 | -------------------------------------------------------------------------------- /configs/model/referencenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | MotionDIr = os.path.join( 5 | os.path.dirname(os.path.abspath(__file__)), "../../checkpoints", "motion" 6 | ) 7 | 8 | 9 | MODEL_CFG = { 10 | "musev_referencenet": { 11 | "net": os.path.join(MotionDIr, "musev_referencenet"), 12 | "desp": "", 13 | }, 14 | } 15 | -------------------------------------------------------------------------------- /configs/tasks/example.yaml: 
-------------------------------------------------------------------------------- 1 | # - name: task_name 2 | # condition_images: path of the vision condition image 3 | # video_path: str, default null, used for video2video 4 | # prompt: text to guide image generation 5 | # ipadapter_image: image_path for IP-Adapter 6 | # refer_image: image_path for referencenet; usually the same as ipadapter_image 7 | # height: int # the smaller the height/width, the larger the motion amplitude and the lower the video quality 8 | # width: int # the larger the height/width, the smaller the motion amplitude and the higher the video quality 9 | # img_length_ratio: float, generated video size is (height, width) * img_length_ratio 10 | 11 | # text/image2video 12 | - condition_images: ./data/images/yongen.jpeg 13 | eye_blinks_factor: 1.8 14 | height: 1308 15 | img_length_ratio: 0.957 16 | ipadapter_image: ${.condition_images} 17 | name: yongen 18 | prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3) 19 | refer_image: ${.condition_images} 20 | video_path: null 21 | width: 736 22 | - condition_images: ./data/images/jinkesi2.jpeg 23 | eye_blinks_factor: 1.8 24 | height: 714 25 | img_length_ratio: 1.25 26 | ipadapter_image: ${.condition_images} 27 | name: jinkesi2 28 | prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face, 29 | soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3) 30 | refer_image: ${.condition_images} 31 | video_path: null 32 | width: 563 33 | - condition_images: ./data/images/seaside4.jpeg 34 | eye_blinks_factor: 1.8 35 | height: 317 36 | img_length_ratio: 2.221 37 | ipadapter_image: ${.condition_images} 38 | name: seaside4 39 | prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene 40 | refer_image: ${.condition_images} 41 | video_path: null 42 | width: 564 43 | - condition_images: ./data/images/seaside_girl.jpeg 44 | eye_blinks_factor: 1.8 45 | height: 736 46 | img_length_ratio: 0.957 47 | ipadapter_image: ${.condition_images} 48 | name: seaside_girl 49 | prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene 50 | refer_image: ${.condition_images} 51 | video_path: null 52 | width: 736 53 | - condition_images: ./data/images/boy_play_guitar.jpeg 54 | eye_blinks_factor: 1.8 55 | height: 846 56 | img_length_ratio: 1.248 57 | ipadapter_image: ${.condition_images} 58 | name: boy_play_guitar 59 | prompt: (masterpiece, best quality, highres:1), playing guitar 60 | refer_image: ${.condition_images} 61 | video_path: null 62 | width: 564 63 | - condition_images: ./data/images/girl_play_guitar2.jpeg 64 | eye_blinks_factor: 1.8 65 | height: 1002 66 | img_length_ratio: 1.248 67 | ipadapter_image: ${.condition_images} 68 | name: girl_play_guitar2 69 | prompt: (masterpiece, best quality, highres:1), playing guitar 70 | refer_image: ${.condition_images} 71 | video_path: null 72 | width: 564 73 | - condition_images: ./data/images/boy_play_guitar2.jpeg 74 | eye_blinks_factor: 1.8 75 | height: 630 76 | img_length_ratio: 1.676 77 | ipadapter_image: ${.condition_images} 78 | name: boy_play_guitar2 79 | prompt: (masterpiece, best quality, highres:1), playing guitar 80 | refer_image: ${.condition_images} 81 | video_path: null 82 | width: 420 83 | - condition_images: ./data/images/girl_play_guitar4.jpeg 84 | eye_blinks_factor: 1.8 85 | height: 846 86 | img_length_ratio: 1.248 87 | ipadapter_image: ${.condition_images} 88 | name: girl_play_guitar4 89 | prompt: (masterpiece, best quality, highres:1), playing
guitar 90 | refer_image: ${.condition_images} 91 | video_path: null 92 | width: 564 93 | - condition_images: ./data/images/dufu.jpeg 94 | eye_blinks_factor: 1.8 95 | height: 500 96 | img_length_ratio: 1.495 97 | ipadapter_image: ${.condition_images} 98 | name: dufu 99 | prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style 100 | refer_image: ${.condition_images} 101 | video_path: null 102 | width: 471 103 | - condition_images: ./data/images/Mona_Lisa..jpg 104 | eye_blinks_factor: 1.8 105 | height: 894 106 | img_length_ratio: 1.173 107 | ipadapter_image: ${.condition_images} 108 | name: Mona_Lisa. 109 | prompt: (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face, 110 | soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3) 111 | refer_image: ${.condition_images} 112 | video_path: null 113 | width: 600 114 | - condition_images: ./data/images/Portrait-of-Dr.-Gachet.jpg 115 | eye_blinks_factor: 1.8 116 | height: 985 117 | img_length_ratio: 0.88 118 | ipadapter_image: ${.condition_images} 119 | name: Portrait-of-Dr.-Gachet 120 | prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3) 121 | refer_image: ${.condition_images} 122 | video_path: null 123 | width: 800 124 | - condition_images: ./data/images/Self-Portrait-with-Cropped-Hair.jpg 125 | eye_blinks_factor: 1.8 126 | height: 565 127 | img_length_ratio: 1.246 128 | ipadapter_image: ${.condition_images} 129 | name: Self-Portrait-with-Cropped-Hair 130 | prompt: (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3), animate 131 | refer_image: ${.condition_images} 132 | video_path: null 133 | width: 848 134 | - condition_images: ./data/images/The-Laughing-Cavalier.jpg 135 | eye_blinks_factor: 1.8 136 | height: 1462 137 | img_length_ratio: 0.587 138 | ipadapter_image: ${.condition_images} 139 | name: The-Laughing-Cavalier 140 | prompt: (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3) 141 | refer_image: ${.condition_images} 142 | video_path: null 143 | width: 1200 144 | 145 | # scene 146 | - condition_images: ./data/images/waterfall4.jpeg 147 | eye_blinks_factor: 1.8 148 | height: 846 149 | img_length_ratio: 1.248 150 | ipadapter_image: ${.condition_images} 151 | name: waterfall4 152 | prompt: (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an 153 | endless waterfall 154 | refer_image: ${.condition_images} 155 | video_path: null 156 | width: 564 157 | - condition_images: ./data/images/river.jpeg 158 | eye_blinks_factor: 1.8 159 | height: 736 160 | img_length_ratio: 0.957 161 | ipadapter_image: ${.condition_images} 162 | name: river 163 | prompt: (masterpiece, best quality, highres:1), peaceful beautiful river 164 | refer_image: ${.condition_images} 165 | video_path: null 166 | width: 736 167 | - condition_images: ./data/images/seaside2.jpeg 168 | eye_blinks_factor: 1.8 169 | height: 1313 170 | img_length_ratio: 0.957 171 | ipadapter_image: ${.condition_images} 172 | name: seaside2 173 | prompt: (masterpiece, best quality, highres:1), peaceful beautiful sea scene 174 | refer_image: ${.condition_images} 175 | video_path: null 176 | width: 736 177 | 178 | # video2video 179 | - name: "dance1" 180 | prompt: "(masterpiece, best quality, highres:1) , a girl is dancing, wearing a dress made of stars, animation" 181 | video_path: ./data/source_video/video1_girl_poseseq.mp4 182 | condition_images: ./data/images/spark_girl.png 
183 | refer_image: ${.condition_images} 184 | ipadapter_image: ${.condition_images} 185 | height: 960 186 | width: 512 187 | img_length_ratio: 1.0 188 | video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video 189 | 190 | - name: "dance2" 191 | prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper" 192 | video_path: ./data/source_video/video1_girl_poseseq.mp4 193 | condition_images: ./data/images/cyber_girl.png 194 | refer_image: ${.condition_images} 195 | ipadapter_image: ${.condition_images} 196 | height: 960 197 | width: 512 198 | img_length_ratio: 1.0 199 | video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video 200 | 201 | - name: "duffy" 202 | prompt: "(best quality), ((masterpiece)), (highres), illustration, original, extremely detailed wallpaper" 203 | video_path: ./data/source_video/pose-for-Duffy-4.mp4 204 | condition_images: ./data/images/duffy.png 205 | refer_image: ${.condition_images} 206 | ipadapter_image: ${.condition_images} 207 | height: 1280 208 | width: 704 209 | img_length_ratio: 1.0 210 | video_is_middle: True # if true, means video_path is controlnet condition, not natural rgb video -------------------------------------------------------------------------------- /data/demo/cyber_girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/demo/cyber_girl.png -------------------------------------------------------------------------------- /data/demo/video1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/demo/video1.mp4 -------------------------------------------------------------------------------- /data/images/Mona_Lisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/Mona_Lisa.jpg -------------------------------------------------------------------------------- /data/images/Portrait-of-Dr.-Gachet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/Portrait-of-Dr.-Gachet.jpg -------------------------------------------------------------------------------- /data/images/Self-Portrait-with-Cropped-Hair.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/Self-Portrait-with-Cropped-Hair.jpg -------------------------------------------------------------------------------- /data/images/The-Laughing-Cavalier.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/The-Laughing-Cavalier.jpg -------------------------------------------------------------------------------- /data/images/boy_play_guitar.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/boy_play_guitar.jpeg 
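The task entries in configs/tasks/example.yaml above use OmegaConf-style relative interpolation: within each task, `${.condition_images}` resolves to that same entry's condition_images value, so refer_image and ipadapter_image default to the same picture as the vision condition. The repo's actual consumers of this file (scripts/inference/text2video.py and video2video.py) are not shown in this section, so the snippet below is only a minimal, standalone sketch of how such a task list could be loaded and resolved with omegaconf (pinned as omegaconf==2.3.0 in environment.yml); the printed fields are illustrative, not the inference API.

from omegaconf import OmegaConf

# Load the top-level YAML list; OmegaConf returns a ListConfig of task entries.
tasks = OmegaConf.load("configs/tasks/example.yaml")
for task in tasks:
    # resolve=True expands interpolations such as ${.condition_images}
    # into the sibling value of the same task entry.
    resolved = OmegaConf.to_container(task, resolve=True)
    print(
        resolved.get("name"),
        resolved["condition_images"],
        resolved["ipadapter_image"],  # same path as condition_images after resolution
        resolved.get("video_is_middle", False),  # True: video_path is already a controlnet condition
    )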
-------------------------------------------------------------------------------- /data/images/boy_play_guitar2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/boy_play_guitar2.jpeg -------------------------------------------------------------------------------- /data/images/cyber_girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/cyber_girl.png -------------------------------------------------------------------------------- /data/images/duffy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/duffy.png -------------------------------------------------------------------------------- /data/images/dufu.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/dufu.jpeg -------------------------------------------------------------------------------- /data/images/girl_play_guitar2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/girl_play_guitar2.jpeg -------------------------------------------------------------------------------- /data/images/girl_play_guitar4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/girl_play_guitar4.jpeg -------------------------------------------------------------------------------- /data/images/jinkesi2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/jinkesi2.jpeg -------------------------------------------------------------------------------- /data/images/river.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/river.jpeg -------------------------------------------------------------------------------- /data/images/seaside2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/seaside2.jpeg -------------------------------------------------------------------------------- /data/images/seaside4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/seaside4.jpeg -------------------------------------------------------------------------------- /data/images/seaside_girl.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/seaside_girl.jpeg -------------------------------------------------------------------------------- /data/images/spark_girl.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/spark_girl.png -------------------------------------------------------------------------------- /data/images/waterfall4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/waterfall4.jpeg -------------------------------------------------------------------------------- /data/images/yongen.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/images/yongen.jpeg -------------------------------------------------------------------------------- /data/models/musev_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/models/musev_structure.png -------------------------------------------------------------------------------- /data/models/parallel_denoise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/models/parallel_denoise.png -------------------------------------------------------------------------------- /data/result_video/Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.mp4 -------------------------------------------------------------------------------- /data/result_video/Portrait-of-Dr.-Gachet.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/Portrait-of-Dr.-Gachet.mp4 -------------------------------------------------------------------------------- /data/result_video/Self-Portrait-with-Cropped-Hair.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/Self-Portrait-with-Cropped-Hair.mp4 -------------------------------------------------------------------------------- /data/result_video/The-Laughing-Cavalier.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/The-Laughing-Cavalier.mp4 -------------------------------------------------------------------------------- /data/result_video/boy_play_guitar.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/boy_play_guitar.mp4 -------------------------------------------------------------------------------- /data/result_video/boy_play_guitar2.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/boy_play_guitar2.mp4 -------------------------------------------------------------------------------- /data/result_video/dufu.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/dufu.mp4 -------------------------------------------------------------------------------- /data/result_video/girl_play_guitar2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/girl_play_guitar2.mp4 -------------------------------------------------------------------------------- /data/result_video/girl_play_guitar4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/girl_play_guitar4.mp4 -------------------------------------------------------------------------------- /data/result_video/jinkesi2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/jinkesi2.mp4 -------------------------------------------------------------------------------- /data/result_video/river.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/river.mp4 -------------------------------------------------------------------------------- /data/result_video/seaside2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/seaside2.mp4 -------------------------------------------------------------------------------- /data/result_video/seaside4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/seaside4.mp4 -------------------------------------------------------------------------------- /data/result_video/seaside_girl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/seaside_girl.mp4 -------------------------------------------------------------------------------- /data/result_video/waterfall4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/waterfall4.mp4 -------------------------------------------------------------------------------- /data/result_video/yongen.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/result_video/yongen.mp4 -------------------------------------------------------------------------------- /data/source_video/pose-for-Duffy-4.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/source_video/pose-for-Duffy-4.mp4 -------------------------------------------------------------------------------- /data/source_video/video1_girl_poseseq.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/data/source_video/video1_girl_poseseq.mp4 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: musev 2 | channels: 3 | - https://repo.anaconda.com/pkgs/main 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2023.12.12=h06a4308_0 10 | - ld_impl_linux-64=2.38=h1181459_1 11 | - libffi=3.3=he6710b0_2 12 | - libgcc-ng=11.2.0=h1234567_1 13 | - libgomp=11.2.0=h1234567_1 14 | - libstdcxx-ng=11.2.0=h1234567_1 15 | - libuuid=1.41.5=h5eee18b_0 16 | - ncurses=6.4=h6a678d5_0 17 | - openssl=1.1.1w=h7f8727e_0 18 | - python=3.10.6=haa1d7c7_1 19 | - readline=8.2=h5eee18b_0 20 | - sqlite=3.41.2=h5eee18b_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - xz=5.4.5=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==2.1.0 26 | - accelerate==0.22.0 27 | - addict==2.4.0 28 | - aiofiles==23.2.1 29 | - aiohttp==3.9.1 30 | - aiosignal==1.3.1 31 | - albumentations==1.3.1 32 | - aliyun-python-sdk-core==2.14.0 33 | - aliyun-python-sdk-kms==2.16.2 34 | - altair==5.2.0 35 | - antlr4-python3-runtime==4.9.3 36 | - anyio==4.2.0 37 | - appdirs==1.4.4 38 | - argparse==1.4.0 39 | - asttokens==2.4.1 40 | - astunparse==1.6.3 41 | - async-timeout==4.0.3 42 | - attrs==23.2.0 43 | - audioread==3.0.1 44 | - basicsr==1.4.2 45 | - beautifulsoup4==4.12.2 46 | - bitsandbytes==0.41.1 47 | - black==23.12.1 48 | - blinker==1.7.0 49 | - braceexpand==0.1.7 50 | - cachetools==5.3.2 51 | - certifi==2023.11.17 52 | - cffi==1.16.0 53 | - charset-normalizer==3.3.2 54 | - chumpy==0.70 55 | - click==8.1.7 56 | - cmake==3.28.1 57 | - colorama==0.4.6 58 | - coloredlogs==15.0.1 59 | - comm==0.2.1 60 | - contourpy==1.2.0 61 | - cos-python-sdk-v5==1.9.22 62 | - coscmd==1.8.6.30 63 | - crcmod==1.7 64 | - cryptography==41.0.7 65 | - cycler==0.12.1 66 | - cython==3.0.2 67 | - datetime==5.4 68 | - debugpy==1.8.0 69 | - decorator==4.4.2 70 | - decord==0.6.0 71 | - dill==0.3.7 72 | - docker-pycreds==0.4.0 73 | - dulwich==0.21.7 74 | - easydict==1.11 75 | - einops==0.7.0 76 | - exceptiongroup==1.2.0 77 | - executing==2.0.1 78 | - fastapi==0.109.0 79 | - ffmpeg==1.4 80 | - ffmpeg-python==0.2.0 81 | - ffmpy==0.3.1 82 | - filelock==3.13.1 83 | - flatbuffers==23.5.26 84 | - fonttools==4.47.2 85 | - frozenlist==1.4.1 86 | - fsspec==2023.12.2 87 | - ftfy==6.1.1 88 | - future==0.18.3 89 | - fuzzywuzzy==0.18.0 90 | - fvcore==0.1.5.post20221221 91 | - gast==0.4.0 92 | - gdown==4.5.3 93 | - gitdb==4.0.11 94 | - gitpython==3.1.41 95 | - google-auth==2.26.2 96 | - google-auth-oauthlib==0.4.6 97 | - google-pasta==0.2.0 98 | - gradio==3.43.2 99 | - gradio-client==0.5.0 100 | - grpcio==1.60.0 101 | - h11==0.14.0 102 | - h5py==3.10.0 103 | - httpcore==1.0.2 104 | - httpx==0.26.0 105 | - huggingface-hub==0.20.2 106 | - humanfriendly==10.0 107 | - idna==3.6 108 | - imageio==2.31.1 109 | - imageio-ffmpeg==0.4.8 110 | - importlib-metadata==7.0.1 111 | - importlib-resources==6.1.1 112 | - iniconfig==2.0.0 113 | - 
insightface==0.7.3 114 | - invisible-watermark==0.1.5 115 | - iopath==0.1.10 116 | - ip-adapter==0.1.0 117 | - iprogress==0.4 118 | - ipykernel==6.29.0 119 | - ipython==8.20.0 120 | - ipywidgets==8.0.3 121 | - jax==0.4.23 122 | - jedi==0.19.1 123 | - jinja2==3.1.3 124 | - jmespath==0.10.0 125 | - joblib==1.3.2 126 | - json-tricks==3.17.3 127 | - jsonschema==4.21.0 128 | - jsonschema-specifications==2023.12.1 129 | - jupyter-client==8.6.0 130 | - jupyter-core==5.7.1 131 | - jupyterlab-widgets==3.0.9 132 | - keras==2.12.0 133 | - kiwisolver==1.4.5 134 | - kornia==0.7.0 135 | - lazy-loader==0.3 136 | - libclang==16.0.6 137 | - librosa==0.10.1 138 | - lightning-utilities==0.10.0 139 | - lit==17.0.6 140 | - llvmlite==0.41.1 141 | - lmdb==1.4.1 142 | - loguru==0.6.0 143 | - markdown==3.5.2 144 | - markdown-it-py==3.0.0 145 | - markupsafe==2.0.1 146 | - matplotlib==3.6.2 147 | - matplotlib-inline==0.1.6 148 | - mdurl==0.1.2 149 | - mediapipe==0.10.3 150 | - ml-dtypes==0.3.2 151 | - model-index==0.1.11 152 | - modelcards==0.1.6 153 | - moviepy==1.0.3 154 | - mpmath==1.3.0 155 | - msgpack==1.0.7 156 | - multidict==6.0.4 157 | - munkres==1.1.4 158 | - mypy-extensions==1.0.0 159 | - nest-asyncio==1.5.9 160 | - networkx==3.2.1 161 | - ninja==1.11.1 162 | - numba==0.58.1 163 | - numpy==1.23.5 164 | - oauthlib==3.2.2 165 | - omegaconf==2.3.0 166 | - onnx==1.14.1 167 | - onnxruntime==1.15.1 168 | - onnxsim==0.4.33 169 | - open-clip-torch==2.20.0 170 | - opencv-contrib-python==4.8.0.76 171 | - opencv-python==4.9.0.80 172 | - opencv-python-headless==4.9.0.80 173 | - opendatalab==0.0.10 174 | - openmim==0.3.9 175 | - openxlab==0.0.34 176 | - opt-einsum==3.3.0 177 | - ordered-set==4.1.0 178 | - orjson==3.9.10 179 | - oss2==2.17.0 180 | - packaging==23.2 181 | - pandas==2.1.4 182 | - parso==0.8.3 183 | - pathspec==0.12.1 184 | - pathtools==0.1.2 185 | - pexpect==4.9.0 186 | - pillow==10.2.0 187 | - pip==23.3.1 188 | - platformdirs==4.1.0 189 | - pluggy==1.3.0 190 | - pooch==1.8.0 191 | - portalocker==2.8.2 192 | - prettytable==3.9.0 193 | - proglog==0.1.10 194 | - prompt-toolkit==3.0.43 195 | - protobuf==3.20.3 196 | - psutil==5.9.7 197 | - ptyprocess==0.7.0 198 | - pure-eval==0.2.2 199 | - pyarrow==14.0.2 200 | - pyasn1==0.5.1 201 | - pyasn1-modules==0.3.0 202 | - pycocotools==2.0.7 203 | - pycparser==2.21 204 | - pycryptodome==3.20.0 205 | - pydantic==1.10.2 206 | - pydeck==0.8.1b0 207 | - pydub==0.25.1 208 | - pygments==2.17.2 209 | - pynvml==11.5.0 210 | - pyparsing==3.1.1 211 | - pysocks==1.7.1 212 | - pytest==7.4.4 213 | - python-dateutil==2.8.2 214 | - python-dotenv==1.0.0 215 | - python-multipart==0.0.6 216 | - pytorch-lightning==2.0.8 217 | - pytube==15.0.0 218 | - pytz==2023.3.post1 219 | - pywavelets==1.5.0 220 | - pyyaml==6.0.1 221 | - pyzmq==25.1.2 222 | - qudida==0.0.4 223 | - redis==4.5.1 224 | - referencing==0.32.1 225 | - regex==2023.12.25 226 | - requests==2.28.2 227 | - requests-oauthlib==1.3.1 228 | - rich==13.4.2 229 | - rpds-py==0.17.1 230 | - rsa==4.9 231 | - safetensors==0.3.3 232 | - scikit-image==0.22.0 233 | - scikit-learn==1.3.2 234 | - scipy==1.11.4 235 | - semantic-version==2.10.0 236 | - sentencepiece==0.1.99 237 | - sentry-sdk==1.39.2 238 | - setproctitle==1.3.3 239 | - setuptools==60.2.0 240 | - shapely==2.0.2 241 | - six==1.16.0 242 | - smmap==5.0.1 243 | - sniffio==1.3.0 244 | - sounddevice==0.4.6 245 | - soundfile==0.12.1 246 | - soupsieve==2.5 247 | - soxr==0.3.7 248 | - stack-data==0.6.3 249 | - starlette==0.35.1 250 | - streamlit==1.30.0 251 | - 
streamlit-drawable-canvas==0.9.3 252 | - sympy==1.12 253 | - tabulate==0.9.0 254 | - tb-nightly==2.11.0a20220906 255 | - tenacity==8.2.3 256 | - tensorboard==2.12.0 257 | - tensorboard-data-server==0.6.1 258 | - tensorboard-plugin-wit==1.8.1 259 | - tensorflow==2.12.0 260 | - tensorflow-estimator==2.12.0 261 | - tensorflow-io-gcs-filesystem==0.35.0 262 | - termcolor==2.4.0 263 | - terminaltables==3.1.10 264 | - test-tube==0.7.5 265 | - threadpoolctl==3.2.0 266 | - tifffile==2023.12.9 267 | - timm==0.9.12 268 | - tokenizers==0.13.3 269 | - toml==0.10.2 270 | - tomli==2.0.1 271 | - toolz==0.12.0 272 | - torch==2.0.1+cu118 273 | - torch-tb-profiler==0.4.1 274 | - torchmetrics==1.1.1 275 | - torchvision==0.15.2+cu118 276 | - tornado==6.4 277 | - tqdm==4.65.2 278 | - traitlets==5.14.1 279 | - transformers==4.33.1 280 | - triton==2.0.0 281 | - typing-extensions==4.9.0 282 | - tzdata==2023.4 283 | - tzlocal==5.2 284 | - urllib3==1.26.18 285 | - urwid==2.4.2 286 | - uvicorn==0.26.0 287 | - validators==0.22.0 288 | - wandb==0.15.10 289 | - watchdog==3.0.0 290 | - wcwidth==0.2.13 291 | - webdataset==0.2.86 292 | - webp==0.3.0 293 | - websockets==11.0.3 294 | - werkzeug==3.0.1 295 | - wget==3.2 296 | - wheel==0.41.2 297 | - widgetsnbextension==4.0.9 298 | - wrapt==1.14.1 299 | - xformers==0.0.21 300 | - xmltodict==0.13.0 301 | - xtcocotools==1.14.3 302 | - yacs==0.1.8 303 | - yapf==0.40.2 304 | - yarl==1.9.4 305 | - zipp==3.17.0 306 | - zope-interface==6.1 307 | - fire==0.6.0 308 | - cuid 309 | - git+https://github.com/tencent-ailab/IP-Adapter.git@main 310 | - git+https://github.com/openai/CLIP.git@main 311 | prefix: /data/miniconda3/envs/musev 312 | 313 | -------------------------------------------------------------------------------- /musev/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import logging.config 4 | 5 | # load the logging configuration file 6 | logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf")) 7 | 8 | # create the package logger 9 | logger = logging.getLogger("musev") 10 | -------------------------------------------------------------------------------- /musev/auto_prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/musev/auto_prompt/__init__.py -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/__init__.py: -------------------------------------------------------------------------------- 1 | from ...utils.register import Register 2 | 3 | AttrRegister = Register(registry_name="attributes") 4 | 5 | # must import like below to ensure that each class is registered with AttrRegister: 6 | from .human import * 7 | from .render import * 8 | from .style import * 9 | -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/attr2template.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 中文 3 | 该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。 4 | 提词(prompt)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。 5 | 该模块主要有三种类,分别是: 6 | 1. `BaseAttribute2Text`: 单属性文本转换类 7 | 2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。 8 | 3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。 9 | 1.
`template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`; 10 | 2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来; 11 | 3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序; 12 | 13 | English 14 | This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency. 15 | 16 | Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming. 17 | 18 | This module mainly consists of three classes: 19 | 20 | BaseAttribute2Text: A class for converting single attribute text. 21 | MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate. 22 | MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input. 23 | If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network. 24 | If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table. 25 | If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate. 26 | """ 27 | 28 | from typing import List, Tuple, Union 29 | 30 | from mmcm.utils.str_util import ( 31 | has_key_brace, 32 | merge_near_same_char, 33 | get_word_from_key_brace_string, 34 | ) 35 | 36 | from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText 37 | from . import AttrRegister 38 | 39 | 40 | class MultiAttr2PromptTemplate(object): 41 | """ 42 | 将多属性转化为模型输入文本的实际类 43 | The actual class that converts multiple attributes into model input text is 44 | """ 45 | 46 | def __init__( 47 | self, 48 | template: str, 49 | attr2text: MultiAttr2Text, 50 | name: str, 51 | ) -> None: 52 | """ 53 | Args: 54 | template (str): 提词模板, prompt template. 55 | 如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key 56 | 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list. 57 | attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt. 58 | name (str): 该多属性文本模板类的名字,便于记忆. 
Class Instance name 59 | """ 60 | self.attr2text = attr2text 61 | self.name = name 62 | if template == "": 63 | template = "{}" 64 | self.template = template 65 | self.template_has_key_brace = has_key_brace(template) 66 | 67 | def __call__(self, attributes: dict) -> Union[str, List[str]]: 68 | texts = self.attr2text(attributes) 69 | if not isinstance(texts, list): 70 | texts = [texts] 71 | prompts = [merge_multi_attrtext(text, self.template) for text in texts] 72 | prompts = [merge_near_same_char(prompt) for prompt in prompts] 73 | if len(prompts) == 1: 74 | prompts = prompts[0] 75 | return prompts 76 | 77 | 78 | class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate): 79 | def __init__(self, template: str, name: str = "keywords") -> None: 80 | """关键词模板属性2文本转化类 81 | 1. 获取关键词模板字符串中的关键词属性; 82 | 2. 从import * 存储在locals()中变量中获取对应的类; 83 | 3. 将集成了多属性转换类的`MultiAttr2Text` 84 | Args: 85 | template (str): 含有{key}的模板字符串 86 | name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords". 87 | 88 | class for converting keyword template attributes to text 89 | 1. Get the keyword attributes in the keyword template string; 90 | 2. Get the corresponding class from the variables stored in locals() by import *; 91 | 3. The `MultiAttr2Text` integrated with multiple attribute conversion classes 92 | Args: 93 | template (str): template string containing {key} 94 | name (str, optional): the name of the template string, no actual use. Defaults to "keywords". 95 | """ 96 | assert has_key_brace( 97 | template 98 | ), "template should have key brace, but given {}".format(template) 99 | keywords = get_word_from_key_brace_string(template) 100 | funcs = [] 101 | for word in keywords: 102 | if word in AttrRegister: 103 | func = AttrRegister[word](name=word) 104 | else: 105 | func = AttriributeIsText(name=word) 106 | funcs.append(func) 107 | attr2text = MultiAttr2Text(funcs, name=name) 108 | super().__init__(template, attr2text, name) 109 | 110 | 111 | class OnlySpacePromptTemplate(MultiAttr2PromptTemplate): 112 | def __init__(self, template: str, name: str = "space_prompt") -> None: 113 | """纯空模板,无论输入啥,都只返回空格字符串作为prompt。 114 | Args: 115 | template (str): 符合只输出空格字符串的模板, 116 | name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt". 117 | 118 | Pure empty template, no matter what the input is, it will only return a space string as the prompt. 119 | Args: 120 | template (str): template that only outputs a space string, 121 | name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt". 122 | """ 123 | attr2text = None 124 | super().__init__(template, attr2text, name) 125 | 126 | def __call__(self, attributes: dict) -> Union[str, List[str]]: 127 | return "" 128 | -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/attributes.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import List, Tuple, Dict 3 | 4 | from mmcm.utils.str_util import has_key_brace 5 | 6 | 7 | class BaseAttribute2Text(object): 8 | """ 9 | 属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。 10 | Base class for converting attributes to text which converts attributes to prompt text. 
11 | """ 12 | 13 | name = "base_attribute" 14 | 15 | def __init__(self, name: str = None) -> None: 16 | """这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。 17 | Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters. 18 | 19 | Args: 20 | name (str, optional): _description_. Defaults to None. 21 | """ 22 | if name is not None: 23 | self.name = name 24 | 25 | def __call__(self, attributes) -> str: 26 | raise NotImplementedError 27 | 28 | 29 | class AttributeIsTextAndName(BaseAttribute2Text): 30 | """ 31 | 属性文本转换功能类,将key和value拼接在一起作为文本. 32 | class for converting attributes to text which concatenates the key and value together as text. 33 | """ 34 | 35 | name = "attribute_is_text_name" 36 | 37 | def __call__(self, attributes) -> str: 38 | if attributes == "" or attributes is None: 39 | return "" 40 | attributes = attributes.split(",") 41 | text = ", ".join( 42 | [ 43 | "{} {}".format(attr, self.name) if attr != "" else "" 44 | for attr in attributes 45 | ] 46 | ) 47 | return text 48 | 49 | 50 | class AttriributeIsText(BaseAttribute2Text): 51 | """ 52 | 属性文本转换功能类,将value作为文本. 53 | class for converting attributes to text which only uses the value as text. 54 | """ 55 | 56 | name = "attribute_is_text" 57 | 58 | def __call__(self, attributes: str) -> str: 59 | if attributes == "" or attributes is None: 60 | return "" 61 | attributes = str(attributes) 62 | attributes = attributes.split(",") 63 | text = ", ".join(["{}".format(attr) for attr in attributes]) 64 | return text 65 | 66 | 67 | class MultiAttr2Text(object): 68 | """将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号 69 | class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol. 70 | 71 | Args: 72 | object (_type_): _description_ 73 | """ 74 | 75 | def __init__(self, funcs: list, name) -> None: 76 | """ 77 | Args: 78 | funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class. 79 | name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about. 80 | """ 81 | if not isinstance(funcs, list): 82 | funcs = [funcs] 83 | self.funcs = funcs 84 | self.name = name 85 | 86 | def __call__( 87 | self, dct: dict, ignored_blank_str: bool = False 88 | ) -> List[Tuple[str, str]]: 89 | """ 90 | 有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。 91 | sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product. 92 | Args: 93 | dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本. 94 | Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text. 95 | ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`. 96 | If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`. 97 | Returns: 98 | Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries. 
99 | """ 100 | attrs_lst = [[]] 101 | for func in self.funcs: 102 | if func.name in dct: 103 | attrs = func(dct[func.name]) 104 | if isinstance(attrs, str): 105 | for i in range(len(attrs_lst)): 106 | attrs_lst[i].append((func.name, attrs)) 107 | else: 108 | # 一个属性可能会返回多个文本 109 | n_attrs = len(attrs) 110 | new_attrs_lst = [] 111 | for n in range(n_attrs): 112 | attrs_lst_cp = deepcopy(attrs_lst) 113 | for i in range(len(attrs_lst_cp)): 114 | attrs_lst_cp[i].append((func.name, attrs[n])) 115 | new_attrs_lst.extend(attrs_lst_cp) 116 | attrs_lst = new_attrs_lst 117 | 118 | texts = [ 119 | [ 120 | (attr, text) 121 | for (attr, text) in attrs 122 | if not (text == "" and ignored_blank_str) 123 | ] 124 | for attrs in attrs_lst 125 | ] 126 | return texts 127 | 128 | 129 | def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str: 130 | """使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本 131 | concatenate multiple attribute text tuples using a template containing "{}" to form a new text 132 | Args: 133 | template (str): 134 | texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples 135 | 136 | Returns: 137 | str: 拼接后的新文本, merged new text 138 | """ 139 | merged_text = ", ".join([text[1] for text in texts if text[1] != ""]) 140 | merged_text = template.format(merged_text) 141 | return merged_text 142 | 143 | 144 | def format_dct_texts(template: str, texts: Dict[str, str]) -> str: 145 | """使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本 146 | concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text 147 | Args: 148 | template (str): 149 | texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries 150 | 151 | Returns: 152 | str: 拼接后的新文本, merged new text 153 | """ 154 | merged_text = template.format(**texts) 155 | return merged_text 156 | 157 | 158 | def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str: 159 | """对多属性文本元组进行拼接,形成新文本。 160 | 如果`template`含有{key},则根据key来取值; 161 | 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。 162 | 163 | concatenate multiple attribute text tuples to form a new text. 164 | if `template` contains {key}, the value is taken according to the key; 165 | if `template` contains only one {}, the values in texts are concatenated in order. 166 | Args: 167 | texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本. 168 | Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion. 169 | template (str, optional): template . Defaults to None. 
170 | 171 | Returns: 172 | str: 拼接后的新文本, merged new text 173 | """ 174 | if not isinstance(texts, List): 175 | texts = [texts] 176 | if template is None or template == "": 177 | template = "{}" 178 | if has_key_brace(template): 179 | texts = {k: v for k, v in texts} 180 | merged_text = format_dct_texts(template, texts) 181 | else: 182 | merged_text = format_tuple_texts(template, texts) 183 | return merged_text 184 | 185 | 186 | class PresetMultiAttr2Text(MultiAttr2Text): 187 | """预置了多种关注属性转换的类,方便维护 188 | class for multiple attribute conversion with multiple attention attributes preset for easy maintenance 189 | 190 | """ 191 | 192 | preset_attributes = [] 193 | 194 | def __init__( 195 | self, funcs: List = None, use_preset: bool = True, name: str = "preset" 196 | ) -> None: 197 | """虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。 198 | 注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。 199 | 200 | Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance. 201 | Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions. 202 | 203 | Args: 204 | funcs (List, optional): list of funcs . Defaults to None. 205 | use_preset (bool, optional): _description_. Defaults to True. 206 | name (str, optional): _description_. Defaults to "preset". 207 | """ 208 | if use_preset: 209 | preset_funcs = self.preset() 210 | else: 211 | preset_funcs = [] 212 | if funcs is None: 213 | funcs = [] 214 | if not isinstance(funcs, list): 215 | funcs = [funcs] 216 | funcs_names = [func.name for func in funcs] 217 | preset_funcs = [ 218 | preset_func 219 | for preset_func in preset_funcs 220 | if preset_func.name not in funcs_names 221 | ] 222 | funcs = funcs + preset_funcs 223 | super().__init__(funcs, name) 224 | 225 | def preset(self): 226 | funcs = [cls() for cls in self.preset_attributes] 227 | return funcs 228 | -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/human.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | import random 4 | import json 5 | 6 | from .attributes import ( 7 | MultiAttr2Text, 8 | AttriributeIsText, 9 | AttributeIsTextAndName, 10 | PresetMultiAttr2Text, 11 | ) 12 | from .style import Style 13 | from .render import Render 14 | from . 
import AttrRegister 15 | 16 | 17 | __all__ = [ 18 | "Age", 19 | "Sex", 20 | "Singing", 21 | "Country", 22 | "Lighting", 23 | "Headwear", 24 | "Eyes", 25 | "Irises", 26 | "Hair", 27 | "Skin", 28 | "Face", 29 | "Smile", 30 | "Expression", 31 | "Clothes", 32 | "Nose", 33 | "Mouth", 34 | "Beard", 35 | "Necklace", 36 | "KeyWords", 37 | "InsightFace", 38 | "Caption", 39 | "Env", 40 | "Decoration", 41 | "Festival", 42 | "SpringHeadwear", 43 | "SpringClothes", 44 | "Animal", 45 | ] 46 | 47 | 48 | @AttrRegister.register 49 | class Sex(AttriributeIsText): 50 | name = "sex" 51 | 52 | def __init__(self, name: str = None) -> None: 53 | super().__init__(name) 54 | 55 | 56 | @AttrRegister.register 57 | class Headwear(AttriributeIsText): 58 | name = "headwear" 59 | 60 | def __init__(self, name: str = None) -> None: 61 | super().__init__(name) 62 | 63 | 64 | @AttrRegister.register 65 | class Expression(AttriributeIsText): 66 | name = "expression" 67 | 68 | def __init__(self, name: str = None) -> None: 69 | super().__init__(name) 70 | 71 | 72 | @AttrRegister.register 73 | class KeyWords(AttriributeIsText): 74 | name = "keywords" 75 | 76 | def __init__(self, name: str = None) -> None: 77 | super().__init__(name) 78 | 79 | 80 | @AttrRegister.register 81 | class Singing(AttriributeIsText): 82 | def __init__(self, name: str = "singing") -> None: 83 | super().__init__(name) 84 | 85 | 86 | @AttrRegister.register 87 | class Country(AttriributeIsText): 88 | name = "country" 89 | 90 | def __init__(self, name: str = None) -> None: 91 | super().__init__(name) 92 | 93 | 94 | @AttrRegister.register 95 | class Clothes(AttriributeIsText): 96 | name = "clothes" 97 | 98 | def __init__(self, name: str = None) -> None: 99 | super().__init__(name) 100 | 101 | 102 | @AttrRegister.register 103 | class Age(AttributeIsTextAndName): 104 | name = "age" 105 | 106 | def __init__(self, name: str = None) -> None: 107 | super().__init__(name) 108 | 109 | def __call__(self, attributes: str) -> str: 110 | if not isinstance(attributes, str): 111 | attributes = str(attributes) 112 | attributes = attributes.split(",") 113 | text = ", ".join( 114 | ["{}-year-old".format(attr) if attr != "" else "" for attr in attributes] 115 | ) 116 | return text 117 | 118 | 119 | @AttrRegister.register 120 | class Eyes(AttributeIsTextAndName): 121 | name = "eyes" 122 | 123 | def __init__(self, name: str = None) -> None: 124 | super().__init__(name) 125 | 126 | 127 | @AttrRegister.register 128 | class Hair(AttributeIsTextAndName): 129 | name = "hair" 130 | 131 | def __init__(self, name: str = None) -> None: 132 | super().__init__(name) 133 | 134 | 135 | @AttrRegister.register 136 | class Background(AttributeIsTextAndName): 137 | name = "background" 138 | 139 | def __init__(self, name: str = None) -> None: 140 | super().__init__(name) 141 | 142 | 143 | @AttrRegister.register 144 | class Skin(AttributeIsTextAndName): 145 | name = "skin" 146 | 147 | def __init__(self, name: str = None) -> None: 148 | super().__init__(name) 149 | 150 | 151 | @AttrRegister.register 152 | class Face(AttributeIsTextAndName): 153 | name = "face" 154 | 155 | def __init__(self, name: str = None) -> None: 156 | super().__init__(name) 157 | 158 | 159 | @AttrRegister.register 160 | class Smile(AttributeIsTextAndName): 161 | name = "smile" 162 | 163 | def __init__(self, name: str = None) -> None: 164 | super().__init__(name) 165 | 166 | 167 | @AttrRegister.register 168 | class Nose(AttributeIsTextAndName): 169 | name = "nose" 170 | 171 | def __init__(self, name: str = None) -> None: 172 | 
super().__init__(name) 173 | 174 | 175 | @AttrRegister.register 176 | class Mouth(AttributeIsTextAndName): 177 | name = "mouth" 178 | 179 | def __init__(self, name: str = None) -> None: 180 | super().__init__(name) 181 | 182 | 183 | @AttrRegister.register 184 | class Beard(AttriributeIsText): 185 | name = "beard" 186 | 187 | def __init__(self, name: str = None) -> None: 188 | super().__init__(name) 189 | 190 | 191 | @AttrRegister.register 192 | class Necklace(AttributeIsTextAndName): 193 | name = "necklace" 194 | 195 | def __init__(self, name: str = None) -> None: 196 | super().__init__(name) 197 | 198 | 199 | @AttrRegister.register 200 | class Irises(AttributeIsTextAndName): 201 | name = "irises" 202 | 203 | def __init__(self, name: str = None) -> None: 204 | super().__init__(name) 205 | 206 | 207 | @AttrRegister.register 208 | class Lighting(AttributeIsTextAndName): 209 | name = "lighting" 210 | 211 | def __init__(self, name: str = None) -> None: 212 | super().__init__(name) 213 | 214 | 215 | PresetPortraitAttributes = [ 216 | Age, 217 | Sex, 218 | Singing, 219 | Country, 220 | Lighting, 221 | Headwear, 222 | Eyes, 223 | Irises, 224 | Hair, 225 | Skin, 226 | Face, 227 | Smile, 228 | Expression, 229 | Clothes, 230 | Nose, 231 | Mouth, 232 | Beard, 233 | Necklace, 234 | Style, 235 | KeyWords, 236 | Render, 237 | ] 238 | 239 | 240 | class PortraitMultiAttr2Text(PresetMultiAttr2Text): 241 | preset_attributes = PresetPortraitAttributes 242 | 243 | def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None: 244 | super().__init__(funcs, use_preset, name) 245 | 246 | 247 | @AttrRegister.register 248 | class InsightFace(AttriributeIsText): 249 | name = "insight_face" 250 | face_render_dict = { 251 | "boy": "handsome,elegant", 252 | "girl": "gorgeous,kawaii,colorful", 253 | } 254 | key_words = "delicate face,beautiful eyes" 255 | 256 | def __call__(self, attributes: str) -> str: 257 | """将insight faces 检测的结果转化成prompt 258 | convert the results of insight faces detection to prompt 259 | Args: 260 | face_list (_type_): _description_ 261 | 262 | Returns: 263 | _type_: _description_ 264 | """ 265 | attributes = json.loads(attributes) 266 | face_list = attributes["info"] 267 | if len(face_list) == 0: 268 | return "" 269 | 270 | if attributes["image_type"] == "body": 271 | for face in face_list: 272 | if "black" in face and face["black"]: 273 | return "african,dark skin" 274 | return "" 275 | 276 | gender_dict = {"girl": 0, "boy": 0} 277 | face_render_list = [] 278 | black = False 279 | 280 | for face in face_list: 281 | if face["ratio"] < 0.02: 282 | continue 283 | 284 | if face["gender"] == 0: 285 | gender_dict["girl"] += 1 286 | face_render_list.append(self.face_render_dict["girl"]) 287 | else: 288 | gender_dict["boy"] += 1 289 | face_render_list.append(self.face_render_dict["boy"]) 290 | 291 | if "black" in face and face["black"]: 292 | black = True 293 | 294 | if len(face_render_list) == 0: 295 | return "" 296 | elif len(face_render_list) == 1: 297 | solo = True 298 | else: 299 | solo = False 300 | 301 | gender = "" 302 | for g, num in gender_dict.items(): 303 | if num > 0: 304 | if gender: 305 | gender += ", " 306 | gender += "{}{}".format(num, g) 307 | if num > 1: 308 | gender += "s" 309 | 310 | face_render_list = ",".join(face_render_list) 311 | face_render_list = face_render_list.split(",") 312 | face_render = list(set(face_render_list)) 313 | face_render.sort(key=face_render_list.index) 314 | face_render = ",".join(face_render) 315 | if gender_dict["girl"] == 0: 316 | 
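            # no detected face was classified as female, so bias the prompt toward male subjects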
face_render = "male focus," + face_render 317 | 318 | insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words) 319 | 320 | if solo: 321 | insightface_prompt += ",solo" 322 | if black: 323 | insightface_prompt = "african,dark skin," + insightface_prompt 324 | 325 | return insightface_prompt 326 | 327 | 328 | @AttrRegister.register 329 | class Caption(AttriributeIsText): 330 | name = "caption" 331 | 332 | 333 | @AttrRegister.register 334 | class Env(AttriributeIsText): 335 | name = "env" 336 | envs_list = [ 337 | "east asian architecture", 338 | "fireworks", 339 | "snow, snowflakes", 340 | "snowing, snowflakes", 341 | ] 342 | 343 | def __call__(self, attributes: str = None) -> str: 344 | if attributes != "" and attributes != " " and attributes is not None: 345 | return attributes 346 | else: 347 | return random.choice(self.envs_list) 348 | 349 | 350 | @AttrRegister.register 351 | class Decoration(AttriributeIsText): 352 | name = "decoration" 353 | 354 | def __init__(self, name: str = None) -> None: 355 | self.decoration_list = [ 356 | "chinese knot", 357 | "flowers", 358 | "food", 359 | "lanterns", 360 | "red envelop", 361 | ] 362 | super().__init__(name) 363 | 364 | def __call__(self, attributes: str = None) -> str: 365 | if attributes != "" and attributes != " " and attributes is not None: 366 | return attributes 367 | else: 368 | return random.choice(self.decoration_list) 369 | 370 | 371 | @AttrRegister.register 372 | class Festival(AttriributeIsText): 373 | name = "festival" 374 | festival_list = ["new year"] 375 | 376 | def __init__(self, name: str = None) -> None: 377 | super().__init__(name) 378 | 379 | def __call__(self, attributes: str = None) -> str: 380 | if attributes != "" and attributes != " " and attributes is not None: 381 | return attributes 382 | else: 383 | return random.choice(self.festival_list) 384 | 385 | 386 | @AttrRegister.register 387 | class SpringHeadwear(AttriributeIsText): 388 | name = "spring_headwear" 389 | headwear_list = ["rabbit ears", "rabbit ears, fur hat"] 390 | 391 | def __call__(self, attributes: str = None) -> str: 392 | if attributes != "" and attributes != " " and attributes is not None: 393 | return attributes 394 | else: 395 | return random.choice(self.headwear_list) 396 | 397 | 398 | @AttrRegister.register 399 | class SpringClothes(AttriributeIsText): 400 | name = "spring_clothes" 401 | clothes_list = [ 402 | "mittens,chinese clothes", 403 | "mittens,fur trim", 404 | "mittens,red scarf", 405 | "mittens,winter clothes", 406 | ] 407 | 408 | def __call__(self, attributes: str = None) -> str: 409 | if attributes != "" and attributes != " " and attributes is not None: 410 | return attributes 411 | else: 412 | return random.choice(self.clothes_list) 413 | 414 | 415 | @AttrRegister.register 416 | class Animal(AttriributeIsText): 417 | name = "animal" 418 | animal_list = ["rabbit", "holding rabbits"] 419 | 420 | def __call__(self, attributes: str = None) -> str: 421 | if attributes != "" and attributes != " " and attributes is not None: 422 | return attributes 423 | else: 424 | return random.choice(self.animal_list) 425 | -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/render.py: -------------------------------------------------------------------------------- 1 | from mmcm.utils.util import flatten 2 | 3 | from .attributes import BaseAttribute2Text 4 | from . 
import AttrRegister 5 | 6 | __all__ = ["Render"] 7 | 8 | RenderMap = { 9 | "Epic": "artstation, epic environment, highly detailed, 8k, HD", 10 | "HD": "8k, highly detailed", 11 | "EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k", 12 | "Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation", 13 | "Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k", 14 | "Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k", 15 | } 16 | 17 | 18 | @AttrRegister.register 19 | class Render(BaseAttribute2Text): 20 | name = "render" 21 | 22 | def __init__(self, name: str = None) -> None: 23 | super().__init__(name) 24 | 25 | def __call__(self, attributes: str) -> str: 26 | if attributes == "" or attributes is None: 27 | return "" 28 | attributes = attributes.split(",") 29 | render = [RenderMap[attr] for attr in attributes if attr in RenderMap] 30 | render = flatten(render, ignored_iterable_types=[str]) 31 | if len(render) == 1: 32 | render = render[0] 33 | return render 34 | -------------------------------------------------------------------------------- /musev/auto_prompt/attributes/style.py: -------------------------------------------------------------------------------- 1 | from .attributes import AttriributeIsText 2 | from . import AttrRegister 3 | 4 | __all__ = ["Style"] 5 | 6 | 7 | @AttrRegister.register 8 | class Style(AttriributeIsText): 9 | name = "style" 10 | 11 | def __init__(self, name: str = None) -> None: 12 | super().__init__(name) 13 | -------------------------------------------------------------------------------- /musev/auto_prompt/human.py: -------------------------------------------------------------------------------- 1 | """负责按照人相关的属性转化成提词 2 | """ 3 | from typing import List 4 | 5 | from .attributes.human import PortraitMultiAttr2Text 6 | from .attributes.attributes import BaseAttribute2Text 7 | from .attributes.attr2template import MultiAttr2PromptTemplate 8 | 9 | 10 | class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate): 11 | """可以将任务字典转化为形象提词模板类 12 | template class for converting task dictionaries into image prompt templates 13 | Args: 14 | MultiAttr2PromptTemplate (_type_): _description_ 15 | """ 16 | 17 | templates = "a portrait of {}" 18 | 19 | def __init__( 20 | self, templates: str = None, attr2text: List = None, name: str = "portrait" 21 | ) -> None: 22 | """ 23 | 24 | Args: 25 | templates (str, optional): 形象提词模板,若为None,则使用默认的类属性. Defaults to None. 26 | portrait prompt template, if None, the default class attribute is used. 27 | attr2text (List, optional): 形象类需要新增、更新的属性列表,默认使用PortraitMultiAttr2Text中定义的形象属性. Defaults to None. 28 | the list of attributes that need to be added or updated in the image class, by default, the image attributes defined in PortraitMultiAttr2Text are used. 29 | name (str, optional): 该形象类的名字. Defaults to "portrait". 
30 | class name of this class instance 31 | """ 32 | if ( 33 | attr2text is None 34 | or isinstance(attr2text, list) 35 | or isinstance(attr2text, BaseAttribute2Text) 36 | ): 37 | attr2text = PortraitMultiAttr2Text(funcs=attr2text) 38 | if templates is None: 39 | templates = self.templates 40 | super().__init__(templates, attr2text, name=name) 41 | -------------------------------------------------------------------------------- /musev/auto_prompt/load_template.py: -------------------------------------------------------------------------------- 1 | from mmcm.utils.str_util import has_key_brace 2 | 3 | from .human import PortraitAttr2PromptTemplate 4 | from .attributes.attr2template import ( 5 | KeywordMultiAttr2PromptTemplate, 6 | OnlySpacePromptTemplate, 7 | ) 8 | 9 | 10 | def get_template_by_name(template: str, name: str = None): 11 | """根据 template_name 确定 prompt 生成器类 12 | choose prompt generator class according to template_name 13 | Args: 14 | name (str): template 的名字简称,便于指定. template name abbreviation, for easy reference 15 | 16 | Raises: 17 | ValueError: ValueError: 如果name不在支持的列表中,则报错. if name is not in the supported list, an error is reported. 18 | 19 | Returns: 20 | MultiAttr2PromptTemplate: 能够将任务字典转化为提词的 实现了__call__功能的类. class that can convert task dictionaries into prompts and implements the __call__ function 21 | 22 | """ 23 | if template == "" or template is None: 24 | template = OnlySpacePromptTemplate(template=template) 25 | elif has_key_brace(template): 26 | # if has_key_brace(template): 27 | template = KeywordMultiAttr2PromptTemplate(template=template) 28 | else: 29 | if name == "portrait": 30 | template = PortraitAttr2PromptTemplate(templates=template) 31 | else: 32 | raise ValueError( 33 | "PresetAttr2PromptTemplate only support one of [portrait], but given {}".format( 34 | name 35 | ) 36 | ) 37 | return template 38 | -------------------------------------------------------------------------------- /musev/auto_prompt/util.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, List 3 | 4 | from .load_template import get_template_by_name 5 | 6 | 7 | def generate_prompts(tasks: List[Dict]) -> List[Dict]: 8 | new_tasks = [] 9 | for task in tasks: 10 | task["origin_prompt"] = deepcopy(task["prompt"]) 11 | # 如果prompt单元值含有模板 {},或者 没有填写任何值(默认为空模板),则使用原prompt值 12 | if "{" not in task["prompt"] and len(task["prompt"]) != 0: 13 | new_tasks.append(task) 14 | else: 15 | template = get_template_by_name( 16 | template=task["prompt"], name=task.get("template_name", None) 17 | ) 18 | prompts = template(task) 19 | if not isinstance(prompts, list) and isinstance(prompts, str): 20 | prompts = [prompts] 21 | for prompt in prompts: 22 | task_cp = deepcopy(task) 23 | task_cp["prompt"] = prompt 24 | new_tasks.append(task_cp) 25 | return new_tasks 26 | -------------------------------------------------------------------------------- /musev/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/musev/data/__init__.py -------------------------------------------------------------------------------- /musev/logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root,musev 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=musevFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | 
handlers=consoleHandler 13 | 14 | # logger level 尽量设置低一点 15 | [logger_musev] 16 | level=DEBUG 17 | handlers=consoleHandler 18 | qualname=musev 19 | propagate=0 20 | 21 | # handler level 设置比 logger level高 22 | [handler_consoleHandler] 23 | class=StreamHandler 24 | level=DEBUG 25 | # level=INFO 26 | 27 | formatter=musevFormatter 28 | args=(sys.stdout,) 29 | 30 | [formatter_musevFormatter] 31 | format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s 32 | datefmt= -------------------------------------------------------------------------------- /musev/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ..utils.register import Register 2 | 3 | Model_Register = Register(registry_name="torch_model") 4 | -------------------------------------------------------------------------------- /musev/models/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from einops import rearrange 16 | import torch 17 | from torch.nn import functional as F 18 | import numpy as np 19 | 20 | from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid 21 | 22 | 23 | # ref diffusers.models.embeddings.get_2d_sincos_pos_embed 24 | def get_2d_sincos_pos_embed( 25 | embed_dim, 26 | grid_size_w, 27 | grid_size_h, 28 | cls_token=False, 29 | extra_tokens=0, 30 | norm_length: bool = False, 31 | max_length: float = 2048, 32 | ): 33 | """ 34 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 35 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 36 | """ 37 | if norm_length and grid_size_h <= max_length and grid_size_w <= max_length: 38 | grid_h = np.linspace(0, max_length, grid_size_h) 39 | grid_w = np.linspace(0, max_length, grid_size_w) 40 | else: 41 | grid_h = np.arange(grid_size_h, dtype=np.float32) 42 | grid_w = np.arange(grid_size_w, dtype=np.float32) 43 | grid = np.meshgrid(grid_h, grid_w) # here h goes first 44 | grid = np.stack(grid, axis=0) 45 | 46 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 47 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 48 | if cls_token and extra_tokens > 0: 49 | pos_embed = np.concatenate( 50 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 51 | ) 52 | return pos_embed 53 | 54 | 55 | def resize_spatial_position_emb( 56 | emb: torch.Tensor, 57 | height: int, 58 | width: int, 59 | scale: float = None, 60 | target_height: int = None, 61 | target_width: int = None, 62 | ) -> torch.Tensor: 63 | """_summary_ 64 | 65 | Args: 66 | emb (torch.Tensor): b ( h w) d 67 | height (int): _description_ 68 | width (int): _description_ 69 | scale (float, optional): _description_. Defaults to None. 70 | target_height (int, optional): _description_. Defaults to None. 71 | target_width (int, optional): _description_. Defaults to None. 
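        Example (illustrative sketch; the 16x16 grid and the 320-dim embedding are arbitrary choices)::

            pos = get_2d_sincos_pos_embed(embed_dim=320, grid_size_w=16, grid_size_h=16)  # numpy, shape (256, 320)
            pos = torch.from_numpy(pos).float()
            pos = resize_spatial_position_emb(
                pos, height=16, width=16, target_height=32, target_width=32
            )  # shape (1024, 320)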
72 | 73 | Returns: 74 | torch.Tensor: b (target_height target_width) d 75 | """ 76 | if scale is not None: 77 | target_height = int(height * scale) 78 | target_width = int(width * scale) 79 | emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1) 80 | emb = F.interpolate( 81 | emb, 82 | size=(target_height, target_width), 83 | mode="bicubic", 84 | align_corners=False, 85 | ) 86 | emb = rearrange(emb, "b d h w-> (h w) (b d)") 87 | return emb 88 | -------------------------------------------------------------------------------- /musev/models/facein_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Iterable, Union 3 | import PIL 4 | import cv2 5 | import torch 6 | import argparse 7 | import datetime 8 | import logging 9 | import inspect 10 | import math 11 | import os 12 | import shutil 13 | from typing import Dict, List, Optional, Tuple 14 | from pprint import pprint 15 | from collections import OrderedDict 16 | from dataclasses import dataclass 17 | import gc 18 | import time 19 | 20 | import numpy as np 21 | from omegaconf import OmegaConf 22 | from omegaconf import SCMode 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from einops import rearrange, repeat 28 | import pandas as pd 29 | import h5py 30 | from diffusers.models.modeling_utils import load_state_dict 31 | from diffusers.utils import ( 32 | logging, 33 | ) 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | from mmcm.vision.feature_extractor.clip_vision_extractor import ( 37 | ImageClipVisionFeatureExtractor, 38 | ImageClipVisionFeatureExtractorV2, 39 | ) 40 | from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor 41 | 42 | from ip_adapter.resampler import Resampler 43 | from ip_adapter.ip_adapter import ImageProjModel 44 | 45 | from .unet_loader import update_unet_with_sd 46 | from .unet_3d_condition import UNet3DConditionModel 47 | from .ip_adapter_loader import ip_adapter_keys_list 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | 52 | # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 53 | unet_keys_list = [ 54 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 55 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 56 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 57 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 58 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 59 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 60 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 61 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 62 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 63 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 64 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 65 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 66 | "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 67 | 
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 68 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 69 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 70 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 71 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 72 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 73 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 74 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 75 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 76 | "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 77 | "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 78 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 79 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 80 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 81 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 82 | "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 83 | "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 84 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", 85 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", 86 | ] 87 | 88 | 89 | UNET2IPAadapter_Keys_MAPIING = { 90 | k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) 91 | } 92 | 93 | 94 | def load_facein_extractor_and_proj_by_name( 95 | model_name: str, 96 | ip_ckpt: Tuple[str, nn.Module], 97 | ip_image_encoder: Tuple[str, nn.Module] = None, 98 | cross_attention_dim: int = 768, 99 | clip_embeddings_dim: int = 512, 100 | clip_extra_context_tokens: int = 1, 101 | ip_scale: float = 0.0, 102 | dtype: torch.dtype = torch.float16, 103 | device: str = "cuda", 104 | unet: nn.Module = None, 105 | ) -> nn.Module: 106 | pass 107 | 108 | 109 | def update_unet_facein_cross_attn_param( 110 | unet: UNet3DConditionModel, ip_adapter_state_dict: Dict 111 | ) -> None: 112 | """use independent ip_adapter attn 中的 to_k, to_v in unet 113 | ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']的字典 114 | 115 | 116 | Args: 117 | unet (UNet3DConditionModel): _description_ 118 | ip_adapter_state_dict (Dict): _description_ 119 | """ 120 | pass 121 | -------------------------------------------------------------------------------- /musev/models/ip_adapter_face_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Iterable, Union 3 | import PIL 4 | import cv2 5 | import torch 6 | import argparse 7 | import datetime 8 | import logging 9 | import inspect 10 | import math 11 | import os 12 | import shutil 13 | from typing import Dict, List, Optional, Tuple 14 | from pprint import pprint 15 | from collections import OrderedDict 16 | from dataclasses import dataclass 17 | import gc 18 | import time 19 | 20 | import numpy as np 21 | from omegaconf import OmegaConf 22 | from omegaconf import SCMode 23 | import torch 24 | from torch import nn 25 | 
import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from einops import rearrange, repeat 28 | import pandas as pd 29 | import h5py 30 | from diffusers.models.modeling_utils import load_state_dict 31 | from diffusers.utils import ( 32 | logging, 33 | ) 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | from ip_adapter.resampler import Resampler 37 | from ip_adapter.ip_adapter import ImageProjModel 38 | from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel 39 | 40 | from mmcm.vision.feature_extractor.clip_vision_extractor import ( 41 | ImageClipVisionFeatureExtractor, 42 | ImageClipVisionFeatureExtractorV2, 43 | ) 44 | from mmcm.vision.feature_extractor.insight_face_extractor import ( 45 | InsightFaceExtractorNormEmb, 46 | ) 47 | 48 | 49 | from .unet_loader import update_unet_with_sd 50 | from .unet_3d_condition import UNet3DConditionModel 51 | from .ip_adapter_loader import ip_adapter_keys_list 52 | 53 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 54 | 55 | 56 | # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 57 | unet_keys_list = [ 58 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 59 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 60 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 61 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 62 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 63 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 64 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 65 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 66 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 67 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 68 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 69 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 70 | "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 71 | "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 72 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 73 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 74 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 75 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 76 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 77 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 78 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 79 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 80 | "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 81 | 
"up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 82 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 83 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 84 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 85 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 86 | "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 87 | "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 88 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", 89 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", 90 | ] 91 | 92 | 93 | UNET2IPAadapter_Keys_MAPIING = { 94 | k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) 95 | } 96 | 97 | 98 | def load_ip_adapter_face_extractor_and_proj_by_name( 99 | model_name: str, 100 | ip_ckpt: Tuple[str, nn.Module], 101 | ip_image_encoder: Tuple[str, nn.Module] = None, 102 | cross_attention_dim: int = 768, 103 | clip_embeddings_dim: int = 1024, 104 | clip_extra_context_tokens: int = 4, 105 | ip_scale: float = 0.0, 106 | dtype: torch.dtype = torch.float16, 107 | device: str = "cuda", 108 | unet: nn.Module = None, 109 | ) -> nn.Module: 110 | if model_name == "IPAdapterFaceID": 111 | if ip_image_encoder is not None: 112 | ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb( 113 | pretrained_model_name_or_path=ip_image_encoder, 114 | dtype=dtype, 115 | device=device, 116 | ) 117 | else: 118 | ip_adapter_face_emb_extractor = None 119 | ip_adapter_image_proj = MLPProjModel( 120 | cross_attention_dim=cross_attention_dim, 121 | id_embeddings_dim=clip_embeddings_dim, 122 | num_tokens=clip_extra_context_tokens, 123 | ).to(device, dtype=dtype) 124 | else: 125 | raise ValueError( 126 | f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID" 127 | ) 128 | ip_adapter_state_dict = torch.load( 129 | ip_ckpt, 130 | map_location="cpu", 131 | ) 132 | ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) 133 | if unet is not None and "ip_adapter" in ip_adapter_state_dict: 134 | update_unet_ip_adapter_cross_attn_param( 135 | unet, 136 | ip_adapter_state_dict["ip_adapter"], 137 | ) 138 | logger.info( 139 | f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" 140 | ) 141 | return ( 142 | ip_adapter_face_emb_extractor, 143 | ip_adapter_image_proj, 144 | ) 145 | 146 | 147 | def update_unet_ip_adapter_cross_attn_param( 148 | unet: UNet3DConditionModel, ip_adapter_state_dict: Dict 149 | ) -> None: 150 | """use independent ip_adapter attn 中的 to_k, to_v in unet 151 | ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight'] 152 | 153 | 154 | Args: 155 | unet (UNet3DConditionModel): _description_ 156 | ip_adapter_state_dict (Dict): _description_ 157 | """ 158 | unet_spatial_cross_atnns = unet.spatial_cross_attns[0] 159 | unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns} 160 | for i, (unet_key_more, ip_adapter_key) in enumerate( 161 | UNET2IPAadapter_Keys_MAPIING.items() 162 | ): 163 | ip_adapter_value = ip_adapter_state_dict[ip_adapter_key] 164 | unet_key_more_spit = unet_key_more.split(".") 165 | unet_key = ".".join(unet_key_more_spit[:-3]) 166 | suffix = 
".".join(unet_key_more_spit[-3:]) 167 | logger.debug( 168 | f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}", 169 | ) 170 | if ".ip_adapter_face_to_k" in suffix: 171 | with torch.no_grad(): 172 | unet_spatial_cross_atnns_dct[ 173 | unet_key 174 | ].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data) 175 | else: 176 | with torch.no_grad(): 177 | unet_spatial_cross_atnns_dct[ 178 | unet_key 179 | ].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data) 180 | -------------------------------------------------------------------------------- /musev/models/ip_adapter_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Iterable, Union 3 | import PIL 4 | import cv2 5 | import torch 6 | import argparse 7 | import datetime 8 | import logging 9 | import inspect 10 | import math 11 | import os 12 | import shutil 13 | from typing import Dict, List, Optional, Tuple 14 | from pprint import pprint 15 | from collections import OrderedDict 16 | from dataclasses import dataclass 17 | import gc 18 | import time 19 | 20 | import numpy as np 21 | from omegaconf import OmegaConf 22 | from omegaconf import SCMode 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from einops import rearrange, repeat 28 | import pandas as pd 29 | import h5py 30 | from diffusers.models.modeling_utils import load_state_dict 31 | from diffusers.utils import ( 32 | logging, 33 | ) 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | from mmcm.vision.feature_extractor import clip_vision_extractor 37 | from mmcm.vision.feature_extractor.clip_vision_extractor import ( 38 | ImageClipVisionFeatureExtractor, 39 | ImageClipVisionFeatureExtractorV2, 40 | VerstailSDLastHiddenState2ImageEmb, 41 | ) 42 | 43 | from ip_adapter.resampler import Resampler 44 | from ip_adapter.ip_adapter import ImageProjModel 45 | 46 | from .unet_loader import update_unet_with_sd 47 | from .unet_3d_condition import UNet3DConditionModel 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | 52 | def load_vision_clip_encoder_by_name( 53 | ip_image_encoder: Tuple[str, nn.Module] = None, 54 | dtype: torch.dtype = torch.float16, 55 | device: str = "cuda", 56 | vision_clip_extractor_class_name: str = None, 57 | ) -> nn.Module: 58 | if vision_clip_extractor_class_name is not None: 59 | vision_clip_extractor = getattr( 60 | clip_vision_extractor, vision_clip_extractor_class_name 61 | )( 62 | pretrained_model_name_or_path=ip_image_encoder, 63 | dtype=dtype, 64 | device=device, 65 | ) 66 | else: 67 | vision_clip_extractor = None 68 | return vision_clip_extractor 69 | 70 | 71 | def load_ip_adapter_image_proj_by_name( 72 | model_name: str, 73 | ip_ckpt: Tuple[str, nn.Module] = None, 74 | cross_attention_dim: int = 768, 75 | clip_embeddings_dim: int = 1024, 76 | clip_extra_context_tokens: int = 4, 77 | ip_scale: float = 0.0, 78 | dtype: torch.dtype = torch.float16, 79 | device: str = "cuda", 80 | unet: nn.Module = None, 81 | vision_clip_extractor_class_name: str = None, 82 | ip_image_encoder: Tuple[str, nn.Module] = None, 83 | ) -> nn.Module: 84 | if model_name in [ 85 | "IPAdapter", 86 | "musev_referencenet", 87 | "musev_referencenet_pose", 88 | ]: 89 | ip_adapter_image_proj = ImageProjModel( 90 | cross_attention_dim=cross_attention_dim, 91 | clip_embeddings_dim=clip_embeddings_dim, 92 | 
clip_extra_context_tokens=clip_extra_context_tokens, 93 | ) 94 | 95 | elif model_name == "IPAdapterPlus": 96 | vision_clip_extractor = ImageClipVisionFeatureExtractorV2( 97 | pretrained_model_name_or_path=ip_image_encoder, 98 | dtype=dtype, 99 | device=device, 100 | ) 101 | ip_adapter_image_proj = Resampler( 102 | dim=cross_attention_dim, 103 | depth=4, 104 | dim_head=64, 105 | heads=12, 106 | num_queries=clip_extra_context_tokens, 107 | embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size, 108 | output_dim=cross_attention_dim, 109 | ff_mult=4, 110 | ) 111 | elif model_name in [ 112 | "VerstailSDLastHiddenState2ImageEmb", 113 | "OriginLastHiddenState2ImageEmbd", 114 | "OriginLastHiddenState2Poolout", 115 | ]: 116 | ip_adapter_image_proj = getattr( 117 | clip_vision_extractor, model_name 118 | ).from_pretrained(ip_image_encoder) 119 | else: 120 | raise ValueError( 121 | f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb" 122 | ) 123 | if ip_ckpt is not None: 124 | ip_adapter_state_dict = torch.load( 125 | ip_ckpt, 126 | map_location="cpu", 127 | ) 128 | ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) 129 | if ( 130 | unet is not None 131 | and unet.ip_adapter_cross_attn 132 | and "ip_adapter" in ip_adapter_state_dict 133 | ): 134 | update_unet_ip_adapter_cross_attn_param( 135 | unet, ip_adapter_state_dict["ip_adapter"] 136 | ) 137 | logger.info( 138 | f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" 139 | ) 140 | return ip_adapter_image_proj 141 | 142 | 143 | def load_ip_adapter_vision_clip_encoder_by_name( 144 | model_name: str, 145 | ip_ckpt: Tuple[str, nn.Module], 146 | ip_image_encoder: Tuple[str, nn.Module] = None, 147 | cross_attention_dim: int = 768, 148 | clip_embeddings_dim: int = 1024, 149 | clip_extra_context_tokens: int = 4, 150 | ip_scale: float = 0.0, 151 | dtype: torch.dtype = torch.float16, 152 | device: str = "cuda", 153 | unet: nn.Module = None, 154 | vision_clip_extractor_class_name: str = None, 155 | ) -> nn.Module: 156 | if vision_clip_extractor_class_name is not None: 157 | vision_clip_extractor = getattr( 158 | clip_vision_extractor, vision_clip_extractor_class_name 159 | )( 160 | pretrained_model_name_or_path=ip_image_encoder, 161 | dtype=dtype, 162 | device=device, 163 | ) 164 | else: 165 | vision_clip_extractor = None 166 | if model_name in [ 167 | "IPAdapter", 168 | "musev_referencenet", 169 | ]: 170 | if ip_image_encoder is not None: 171 | if vision_clip_extractor_class_name is None: 172 | vision_clip_extractor = ImageClipVisionFeatureExtractor( 173 | pretrained_model_name_or_path=ip_image_encoder, 174 | dtype=dtype, 175 | device=device, 176 | ) 177 | else: 178 | vision_clip_extractor = None 179 | ip_adapter_image_proj = ImageProjModel( 180 | cross_attention_dim=cross_attention_dim, 181 | clip_embeddings_dim=clip_embeddings_dim, 182 | clip_extra_context_tokens=clip_extra_context_tokens, 183 | ) 184 | 185 | elif model_name == "IPAdapterPlus": 186 | if ip_image_encoder is not None: 187 | if vision_clip_extractor_class_name is None: 188 | vision_clip_extractor = ImageClipVisionFeatureExtractorV2( 189 | pretrained_model_name_or_path=ip_image_encoder, 190 | dtype=dtype, 191 | device=device, 192 | ) 193 | else: 194 | vision_clip_extractor = None 195 | ip_adapter_image_proj = Resampler( 196 | dim=cross_attention_dim, 197 | depth=4, 198 | dim_head=64, 199 | heads=12, 200 | num_queries=clip_extra_context_tokens, 201 | 
embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size, 202 | output_dim=cross_attention_dim, 203 | ff_mult=4, 204 | ).to(dtype=torch.float16) 205 | else: 206 | raise ValueError( 207 | f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus" 208 | ) 209 | ip_adapter_state_dict = torch.load( 210 | ip_ckpt, 211 | map_location="cpu", 212 | ) 213 | ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) 214 | if ( 215 | unet is not None 216 | and unet.ip_adapter_cross_attn 217 | and "ip_adapter" in ip_adapter_state_dict 218 | ): 219 | update_unet_ip_adapter_cross_attn_param( 220 | unet, ip_adapter_state_dict["ip_adapter"] 221 | ) 222 | logger.info( 223 | f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" 224 | ) 225 | return ( 226 | vision_clip_extractor, 227 | ip_adapter_image_proj, 228 | ) 229 | 230 | 231 | # refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 232 | unet_keys_list = [ 233 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 234 | "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 235 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 236 | "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 237 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 238 | "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 239 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 240 | "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 241 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 242 | "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 243 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 244 | "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 245 | "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 246 | "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 247 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 248 | "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 249 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", 250 | "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", 251 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 252 | "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 253 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 254 | "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 255 | "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", 256 | "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", 257 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 258 | "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 259 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", 260 | "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", 261 | "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", 262 | 
"up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", 263 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", 264 | "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", 265 | ] 266 | 267 | 268 | ip_adapter_keys_list = [ 269 | "1.to_k_ip.weight", 270 | "1.to_v_ip.weight", 271 | "3.to_k_ip.weight", 272 | "3.to_v_ip.weight", 273 | "5.to_k_ip.weight", 274 | "5.to_v_ip.weight", 275 | "7.to_k_ip.weight", 276 | "7.to_v_ip.weight", 277 | "9.to_k_ip.weight", 278 | "9.to_v_ip.weight", 279 | "11.to_k_ip.weight", 280 | "11.to_v_ip.weight", 281 | "13.to_k_ip.weight", 282 | "13.to_v_ip.weight", 283 | "15.to_k_ip.weight", 284 | "15.to_v_ip.weight", 285 | "17.to_k_ip.weight", 286 | "17.to_v_ip.weight", 287 | "19.to_k_ip.weight", 288 | "19.to_v_ip.weight", 289 | "21.to_k_ip.weight", 290 | "21.to_v_ip.weight", 291 | "23.to_k_ip.weight", 292 | "23.to_v_ip.weight", 293 | "25.to_k_ip.weight", 294 | "25.to_v_ip.weight", 295 | "27.to_k_ip.weight", 296 | "27.to_v_ip.weight", 297 | "29.to_k_ip.weight", 298 | "29.to_v_ip.weight", 299 | "31.to_k_ip.weight", 300 | "31.to_v_ip.weight", 301 | ] 302 | 303 | UNET2IPAadapter_Keys_MAPIING = { 304 | k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) 305 | } 306 | 307 | 308 | def update_unet_ip_adapter_cross_attn_param( 309 | unet: UNet3DConditionModel, ip_adapter_state_dict: Dict 310 | ) -> None: 311 | """use independent ip_adapter attn 中的 to_k, to_v in unet 312 | ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight'] 313 | 314 | 315 | Args: 316 | unet (UNet3DConditionModel): _description_ 317 | ip_adapter_state_dict (Dict): _description_ 318 | """ 319 | unet_spatial_cross_atnns = unet.spatial_cross_attns[0] 320 | unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns} 321 | for i, (unet_key_more, ip_adapter_key) in enumerate( 322 | UNET2IPAadapter_Keys_MAPIING.items() 323 | ): 324 | ip_adapter_value = ip_adapter_state_dict[ip_adapter_key] 325 | unet_key_more_spit = unet_key_more.split(".") 326 | unet_key = ".".join(unet_key_more_spit[:-3]) 327 | suffix = ".".join(unet_key_more_spit[-3:]) 328 | logger.debug( 329 | f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}", 330 | ) 331 | if "to_k" in suffix: 332 | with torch.no_grad(): 333 | unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_( 334 | ip_adapter_value.data 335 | ) 336 | else: 337 | with torch.no_grad(): 338 | unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_( 339 | ip_adapter_value.data 340 | ) 341 | -------------------------------------------------------------------------------- /musev/models/referencenet_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Iterable, Union 3 | import PIL 4 | import cv2 5 | import torch 6 | import argparse 7 | import datetime 8 | import logging 9 | import inspect 10 | import math 11 | import os 12 | import shutil 13 | from typing import Dict, List, Optional, Tuple 14 | from pprint import pprint 15 | from collections import OrderedDict 16 | from dataclasses import dataclass 17 | import gc 18 | import time 19 | 20 | import numpy as np 21 | from omegaconf import OmegaConf 22 | from omegaconf import SCMode 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from einops import rearrange, repeat 28 | import pandas as pd 29 | import h5py 30 
| from diffusers.models.modeling_utils import load_state_dict 31 | from diffusers.utils import ( 32 | logging, 33 | ) 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | from .referencenet import ReferenceNet2D 37 | from .unet_loader import update_unet_with_sd 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | def load_referencenet( 44 | sd_referencenet_model: Tuple[str, nn.Module], 45 | sd_model: nn.Module = None, 46 | need_self_attn_block_embs: bool = False, 47 | need_block_embs: bool = False, 48 | dtype: torch.dtype = torch.float16, 49 | cross_attention_dim: int = 768, 50 | subfolder: str = "unet", 51 | ): 52 | """ 53 | Loads the ReferenceNet model. 54 | 55 | Args: 56 | sd_referencenet_model (Tuple[str, nn.Module] or str): The pretrained ReferenceNet model or the path to the model. 57 | sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None. 58 | need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False. 59 | need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False. 60 | dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16. 61 | cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768. 62 | subfolder (str, optional): The subfolder of the model. Defaults to "unet". 63 | 64 | Returns: 65 | nn.Module: The loaded ReferenceNet model. 66 | """ 67 | 68 | if isinstance(sd_referencenet_model, str): 69 | referencenet = ReferenceNet2D.from_pretrained( 70 | sd_referencenet_model, 71 | subfolder=subfolder, 72 | need_self_attn_block_embs=need_self_attn_block_embs, 73 | need_block_embs=need_block_embs, 74 | torch_dtype=dtype, 75 | cross_attention_dim=cross_attention_dim, 76 | ) 77 | elif isinstance(sd_referencenet_model, nn.Module): 78 | referencenet = sd_referencenet_model 79 | if sd_model is not None: 80 | referencenet = update_unet_with_sd(referencenet, sd_model) 81 | return referencenet 82 | 83 | 84 | def load_referencenet_by_name( 85 | model_name: str, 86 | sd_referencenet_model: Tuple[str, nn.Module], 87 | sd_model: nn.Module = None, 88 | cross_attention_dim: int = 768, 89 | dtype: torch.dtype = torch.float16, 90 | ) -> nn.Module: 91 | """通过模型名字 初始化 referencenet,载入预训练参数, 92 | 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义 93 | init referencenet with model_name. 94 | if you want to use pretrained model with simple name, you need to define it here. 95 | Args: 96 | model_name (str): _description_ 97 | sd_unet_model (Tuple[str, nn.Module]): _description_ 98 | sd_model (Tuple[str, nn.Module]): _description_ 99 | cross_attention_dim (int, optional): _description_. Defaults to 768. 100 | dtype (torch.dtype, optional): _description_. Defaults to torch.float16. 
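        Example (illustrative; the checkpoint path is a placeholder)::

            referencenet = load_referencenet_by_name(
                model_name="musev_referencenet",
                sd_referencenet_model="./checkpoints/musev_referencenet",
                cross_attention_dim=768,
                dtype=torch.float16,
            )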
101 | 102 | Raises: 103 | ValueError: _description_ 104 | 105 | Returns: 106 | nn.Module: _description_ 107 | """ 108 | if model_name in [ 109 | "musev_referencenet", 110 | ]: 111 | unet = load_referencenet( 112 | sd_referencenet_model=sd_referencenet_model, 113 | sd_model=sd_model, 114 | cross_attention_dim=cross_attention_dim, 115 | dtype=dtype, 116 | need_self_attn_block_embs=False, 117 | need_block_embs=True, 118 | subfolder="referencenet", 119 | ) 120 | else: 121 | raise ValueError( 122 | f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16" 123 | ) 124 | return unet 125 | -------------------------------------------------------------------------------- /musev/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py 17 | from __future__ import annotations 18 | 19 | from functools import partial 20 | from typing import Optional 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | from einops import rearrange, repeat 26 | 27 | from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer 28 | from ..data.data_util import batch_index_fill, batch_index_select 29 | from . 
import Model_Register 30 | 31 | 32 | @Model_Register.register 33 | class TemporalConvLayer(nn.Module): 34 | """ 35 | Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: 36 | https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 37 | """ 38 | 39 | def __init__( 40 | self, 41 | in_dim, 42 | out_dim=None, 43 | dropout=0.0, 44 | keep_content_condition: bool = False, 45 | femb_channels: Optional[int] = None, 46 | need_temporal_weight: bool = True, 47 | ): 48 | super().__init__() 49 | out_dim = out_dim or in_dim 50 | self.in_dim = in_dim 51 | self.out_dim = out_dim 52 | self.keep_content_condition = keep_content_condition 53 | self.femb_channels = femb_channels 54 | self.need_temporal_weight = need_temporal_weight 55 | # conv layers 56 | self.conv1 = nn.Sequential( 57 | nn.GroupNorm(32, in_dim), 58 | nn.SiLU(), 59 | nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), 60 | ) 61 | self.conv2 = nn.Sequential( 62 | nn.GroupNorm(32, out_dim), 63 | nn.SiLU(), 64 | nn.Dropout(dropout), 65 | nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), 66 | ) 67 | self.conv3 = nn.Sequential( 68 | nn.GroupNorm(32, out_dim), 69 | nn.SiLU(), 70 | nn.Dropout(dropout), 71 | nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), 72 | ) 73 | self.conv4 = nn.Sequential( 74 | nn.GroupNorm(32, out_dim), 75 | nn.SiLU(), 76 | nn.Dropout(dropout), 77 | nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), 78 | ) 79 | 80 | # zero out the last layer params,so the conv block is identity 81 | # nn.init.zeros_(self.conv4[-1].weight) 82 | # nn.init.zeros_(self.conv4[-1].bias) 83 | self.temporal_weight = nn.Parameter( 84 | torch.tensor( 85 | [ 86 | 1e-5, 87 | ] 88 | ) 89 | ) # initialize parameter with 0 90 | # zero out the last layer params,so the conv block is identity 91 | nn.init.zeros_(self.conv4[-1].weight) 92 | nn.init.zeros_(self.conv4[-1].bias) 93 | self.skip_temporal_layers = False # Whether to skip temporal layer 94 | 95 | def forward( 96 | self, 97 | hidden_states, 98 | num_frames=1, 99 | sample_index: torch.LongTensor = None, 100 | vision_conditon_frames_sample_index: torch.LongTensor = None, 101 | femb: torch.Tensor = None, 102 | ): 103 | if self.skip_temporal_layers is True: 104 | return hidden_states 105 | hidden_states_dtype = hidden_states.dtype 106 | hidden_states = rearrange( 107 | hidden_states, "(b t) c h w -> b c t h w", t=num_frames 108 | ) 109 | identity = hidden_states 110 | hidden_states = self.conv1(hidden_states) 111 | hidden_states = self.conv2(hidden_states) 112 | hidden_states = self.conv3(hidden_states) 113 | hidden_states = self.conv4(hidden_states) 114 | # 保留condition对应的frames,便于保持前序内容帧,提升一致性 115 | if self.keep_content_condition: 116 | mask = torch.ones_like(hidden_states, device=hidden_states.device) 117 | mask = batch_index_fill( 118 | mask, dim=2, index=vision_conditon_frames_sample_index, value=0 119 | ) 120 | if self.need_temporal_weight: 121 | hidden_states = ( 122 | identity + torch.abs(self.temporal_weight) * mask * hidden_states 123 | ) 124 | else: 125 | hidden_states = identity + mask * hidden_states 126 | else: 127 | if self.need_temporal_weight: 128 | hidden_states = ( 129 | identity + torch.abs(self.temporal_weight) * hidden_states 130 | ) 131 | else: 132 | hidden_states = identity + hidden_states 133 | hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w") 134 | hidden_states = 
hidden_states.to(dtype=hidden_states_dtype) 135 | return hidden_states 136 | -------------------------------------------------------------------------------- /musev/models/super_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from typing import Any, Dict, Tuple, Union, Optional 6 | from einops import rearrange, repeat 7 | from torch import nn 8 | import torch 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.models.modeling_utils import ModelMixin, load_state_dict 12 | 13 | from ..data.data_util import align_repeat_tensor_single_dim 14 | 15 | from .unet_3d_condition import UNet3DConditionModel 16 | from .referencenet import ReferenceNet2D 17 | from ip_adapter.ip_adapter import ImageProjModel 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class SuperUNet3DConditionModel(nn.Module): 23 | """封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。 24 | 主要作用 25 | 1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些; 26 | 2. 便于 accelerator 的分布式训练; 27 | 28 | wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj 29 | 1. support controlnet, referencenet, etc. 30 | 2. support accelerator distributed training 31 | """ 32 | 33 | _supports_gradient_checkpointing = True 34 | print_idx = 0 35 | 36 | # @register_to_config 37 | def __init__( 38 | self, 39 | unet: nn.Module, 40 | referencenet: nn.Module = None, 41 | controlnet: nn.Module = None, 42 | vae: nn.Module = None, 43 | text_encoder: nn.Module = None, 44 | tokenizer: nn.Module = None, 45 | text_emb_extractor: nn.Module = None, 46 | clip_vision_extractor: nn.Module = None, 47 | ip_adapter_image_proj: nn.Module = None, 48 | ) -> None: 49 | """_summary_ 50 | 51 | Args: 52 | unet (nn.Module): _description_ 53 | referencenet (nn.Module, optional): _description_. Defaults to None. 54 | controlnet (nn.Module, optional): _description_. Defaults to None. 55 | vae (nn.Module, optional): _description_. Defaults to None. 56 | text_encoder (nn.Module, optional): _description_. Defaults to None. 57 | tokenizer (nn.Module, optional): _description_. Defaults to None. 58 | text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None. 59 | clip_vision_extractor (nn.Module, optional): _description_. Defaults to None. 60 | """ 61 | super().__init__() 62 | self.unet = unet 63 | self.referencenet = referencenet 64 | self.controlnet = controlnet 65 | self.vae = vae 66 | self.text_encoder = text_encoder 67 | self.tokenizer = tokenizer 68 | self.text_emb_extractor = text_emb_extractor 69 | self.clip_vision_extractor = clip_vision_extractor 70 | self.ip_adapter_image_proj = ip_adapter_image_proj 71 | 72 | def forward( 73 | self, 74 | unet_params: Dict, 75 | encoder_hidden_states: torch.Tensor, 76 | referencenet_params: Dict = None, 77 | controlnet_params: Dict = None, 78 | controlnet_scale: float = 1.0, 79 | vision_clip_emb: Union[torch.Tensor, None] = None, 80 | prompt_only_use_image_prompt: bool = False, 81 | ): 82 | """_summary_ 83 | 84 | Args: 85 | unet_params (Dict): _description_ 86 | encoder_hidden_states (torch.Tensor): b t n d 87 | referencenet_params (Dict, optional): _description_. Defaults to None. 88 | controlnet_params (Dict, optional): _description_. Defaults to None. 89 | controlnet_scale (float, optional): _description_. Defaults to 1.0. 
90 | vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None. 91 | prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False. 92 | 93 | Returns: 94 | _type_: _description_ 95 | """ 96 | batch_size = unet_params["sample"].shape[0] 97 | time_size = unet_params["sample"].shape[2] 98 | 99 | # ip_adapter_cross_attn, prepare image prompt 100 | if vision_clip_emb is not None: 101 | # b t n d -> b t n d 102 | if self.print_idx == 0: 103 | logger.debug( 104 | f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" 105 | ) 106 | if vision_clip_emb.ndim == 3: 107 | vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d") 108 | if self.ip_adapter_image_proj is not None: 109 | vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d") 110 | vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb) 111 | if self.print_idx == 0: 112 | logger.debug( 113 | f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" 114 | ) 115 | if vision_clip_emb.ndim == 2: 116 | vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d") 117 | vision_clip_emb = rearrange( 118 | vision_clip_emb, "(b t) n d -> b t n d", b=batch_size 119 | ) 120 | vision_clip_emb = align_repeat_tensor_single_dim( 121 | vision_clip_emb, target_length=time_size, dim=1 122 | ) 123 | if self.print_idx == 0: 124 | logger.debug( 125 | f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" 126 | ) 127 | 128 | if vision_clip_emb is None and encoder_hidden_states is not None: 129 | vision_clip_emb = encoder_hidden_states 130 | if vision_clip_emb is not None and encoder_hidden_states is None: 131 | encoder_hidden_states = vision_clip_emb 132 | # 当 prompt_only_use_image_prompt 为True时, 133 | # 1. referencenet 都使用 vision_clip_emb 134 | # 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新 135 | # 3. controlnet 当前使用 text_prompt 136 | 137 | # when prompt_only_use_image_prompt True, 138 | # 1. referencenet use vision_clip_emb 139 | # 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update 140 | # 3. 
controlnet use text_prompt 141 | 142 | # extract referencenet emb 143 | if self.referencenet is not None and referencenet_params is not None: 144 | referencenet_encoder_hidden_states = align_repeat_tensor_single_dim( 145 | vision_clip_emb, 146 | target_length=referencenet_params["num_frames"], 147 | dim=1, 148 | ) 149 | referencenet_params["encoder_hidden_states"] = rearrange( 150 | referencenet_encoder_hidden_states, "b t n d->(b t) n d" 151 | ) 152 | referencenet_out = self.referencenet(**referencenet_params) 153 | ( 154 | down_block_refer_embs, 155 | mid_block_refer_emb, 156 | refer_self_attn_emb, 157 | ) = referencenet_out 158 | if down_block_refer_embs is not None: 159 | if self.print_idx == 0: 160 | logger.debug( 161 | f"len(down_block_refer_embs)={len(down_block_refer_embs)}" 162 | ) 163 | for i, down_emb in enumerate(down_block_refer_embs): 164 | if self.print_idx == 0: 165 | logger.debug( 166 | f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}" 167 | ) 168 | else: 169 | if self.print_idx == 0: 170 | logger.debug(f"down_block_refer_embs is None") 171 | if mid_block_refer_emb is not None: 172 | if self.print_idx == 0: 173 | logger.debug( 174 | f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}" 175 | ) 176 | else: 177 | if self.print_idx == 0: 178 | logger.debug(f"mid_block_refer_emb is None") 179 | if refer_self_attn_emb is not None: 180 | if self.print_idx == 0: 181 | logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}") 182 | for i, self_attn_emb in enumerate(refer_self_attn_emb): 183 | if self.print_idx == 0: 184 | logger.debug( 185 | f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}" 186 | ) 187 | else: 188 | if self.print_idx == 0: 189 | logger.debug(f"refer_self_attn_emb is None") 190 | else: 191 | down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = ( 192 | None, 193 | None, 194 | None, 195 | ) 196 | 197 | # extract controlnet emb 198 | if self.controlnet is not None and controlnet_params is not None: 199 | controlnet_encoder_hidden_states = align_repeat_tensor_single_dim( 200 | encoder_hidden_states, 201 | target_length=unet_params["sample"].shape[2], 202 | dim=1, 203 | ) 204 | controlnet_params["encoder_hidden_states"] = rearrange( 205 | controlnet_encoder_hidden_states, " b t n d -> (b t) n d" 206 | ) 207 | ( 208 | down_block_additional_residuals, 209 | mid_block_additional_residual, 210 | ) = self.controlnet(**controlnet_params) 211 | if controlnet_scale != 1.0: 212 | down_block_additional_residuals = [ 213 | x * controlnet_scale for x in down_block_additional_residuals 214 | ] 215 | mid_block_additional_residual = ( 216 | mid_block_additional_residual * controlnet_scale 217 | ) 218 | for i, down_block_additional_residual in enumerate( 219 | down_block_additional_residuals 220 | ): 221 | if self.print_idx == 0: 222 | logger.debug( 223 | f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}" 224 | ) 225 | 226 | if self.print_idx == 0: 227 | logger.debug( 228 | f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}" 229 | ) 230 | else: 231 | down_block_additional_residuals = None 232 | mid_block_additional_residual = None 233 | 234 | if prompt_only_use_image_prompt and vision_clip_emb is not None: 235 | encoder_hidden_states = vision_clip_emb 236 | 237 | # run unet 238 | out = self.unet( 239 | **unet_params, 240 | down_block_refer_embs=down_block_refer_embs, 241 | 
mid_block_refer_emb=mid_block_refer_emb, 242 | refer_self_attn_emb=refer_self_attn_emb, 243 | down_block_additional_residuals=down_block_additional_residuals, 244 | mid_block_additional_residual=mid_block_additional_residual, 245 | encoder_hidden_states=encoder_hidden_states, 246 | vision_clip_emb=vision_clip_emb, 247 | ) 248 | self.print_idx += 1 249 | return out 250 | 251 | def _set_gradient_checkpointing(self, module, value=False): 252 | if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)): 253 | module.gradient_checkpointing = value 254 | -------------------------------------------------------------------------------- /musev/models/temporal_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py 16 | from __future__ import annotations 17 | from copy import deepcopy 18 | from dataclasses import dataclass 19 | from typing import List, Literal, Optional 20 | import logging 21 | 22 | import torch 23 | from torch import nn 24 | from einops import rearrange, repeat 25 | 26 | from diffusers.configuration_utils import ConfigMixin, register_to_config 27 | from diffusers.utils import BaseOutput 28 | from diffusers.models.modeling_utils import ModelMixin 29 | from diffusers.models.transformer_temporal import ( 30 | TransformerTemporalModelOutput, 31 | TransformerTemporalModel as DiffusersTransformerTemporalModel, 32 | ) 33 | from diffusers.models.attention_processor import AttnProcessor 34 | 35 | from mmcm.utils.gpu_util import get_gpu_status 36 | from ..data.data_util import ( 37 | batch_concat_two_tensor_with_index, 38 | batch_index_fill, 39 | batch_index_select, 40 | concat_two_tensor, 41 | align_repeat_tensor_single_dim, 42 | ) 43 | from ..utils.attention_util import generate_sparse_causcal_attn_mask 44 | from .attention import BasicTransformerBlock 45 | from .attention_processor import ( 46 | BaseIPAttnProcessor, 47 | ) 48 | from . import Model_Register 49 | 50 | # https://github.com/facebookresearch/xformers/issues/845 51 | # 输入bs*n_frames*w*h太高,xformers报错。因此将transformer_temporal的allow_xformers均关掉 52 | # if bs*n_frames*w*h to large, xformers will raise error. So we close the allow_xformers in transformer_temporal 53 | logger = logging.getLogger(__name__) 54 | 55 | 56 | @Model_Register.register 57 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 58 | """ 59 | Transformer model for video-like data. 60 | 61 | Parameters: 62 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 63 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 64 | in_channels (`int`, *optional*): 65 | Pass if the input is continuous. The number of channels in the input and output. 
66 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 67 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 68 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 69 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 70 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 71 | `ImagePositionalEmbeddings`. 72 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 73 | attention_bias (`bool`, *optional*): 74 | Configure if the TransformerBlocks' attention should contain a bias parameter. 75 | double_self_attention (`bool`, *optional*): 76 | Configure if each TransformerBlock should contain two self-attention layers 77 | """ 78 | 79 | @register_to_config 80 | def __init__( 81 | self, 82 | num_attention_heads: int = 16, 83 | attention_head_dim: int = 88, 84 | in_channels: Optional[int] = None, 85 | out_channels: Optional[int] = None, 86 | num_layers: int = 1, 87 | femb_channels: Optional[int] = None, 88 | dropout: float = 0.0, 89 | norm_num_groups: int = 32, 90 | cross_attention_dim: Optional[int] = None, 91 | attention_bias: bool = False, 92 | sample_size: Optional[int] = None, 93 | activation_fn: str = "geglu", 94 | norm_elementwise_affine: bool = True, 95 | double_self_attention: bool = True, 96 | allow_xformers: bool = False, 97 | only_cross_attention: bool = False, 98 | keep_content_condition: bool = False, 99 | need_spatial_position_emb: bool = False, 100 | need_temporal_weight: bool = True, 101 | self_attn_mask: str = None, 102 | # TODO: 运行参数,有待改到forward里面去 103 | # TODO: running parameters, need to be moved to forward 104 | image_scale: float = 1.0, 105 | processor: AttnProcessor | None = None, 106 | remove_femb_non_linear: bool = False, 107 | ): 108 | super().__init__() 109 | 110 | self.num_attention_heads = num_attention_heads 111 | self.attention_head_dim = attention_head_dim 112 | 113 | inner_dim = num_attention_heads * attention_head_dim 114 | self.inner_dim = inner_dim 115 | self.in_channels = in_channels 116 | 117 | self.norm = torch.nn.GroupNorm( 118 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 119 | ) 120 | 121 | self.proj_in = nn.Linear(in_channels, inner_dim) 122 | 123 | # 2. Define temporal positional embedding 124 | self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim) 125 | self.remove_femb_non_linear = remove_femb_non_linear 126 | if not remove_femb_non_linear: 127 | self.nonlinearity = nn.SiLU() 128 | 129 | # spatial_position_emb 使用femb_的参数配置 130 | self.need_spatial_position_emb = need_spatial_position_emb 131 | if need_spatial_position_emb: 132 | self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim) 133 | # 3. 
Define transformers blocks 134 | # TODO: 该实现方式不好,待优化 135 | # TODO: bad implementation, need to be optimized 136 | self.need_ipadapter = False 137 | self.cross_attn_temporal_cond = False 138 | self.allow_xformers = allow_xformers 139 | if processor is not None and isinstance(processor, BaseIPAttnProcessor): 140 | self.cross_attn_temporal_cond = True 141 | self.allow_xformers = False 142 | if "NonParam" not in processor.__class__.__name__: 143 | self.need_ipadapter = True 144 | 145 | self.transformer_blocks = nn.ModuleList( 146 | [ 147 | BasicTransformerBlock( 148 | inner_dim, 149 | num_attention_heads, 150 | attention_head_dim, 151 | dropout=dropout, 152 | cross_attention_dim=cross_attention_dim, 153 | activation_fn=activation_fn, 154 | attention_bias=attention_bias, 155 | double_self_attention=double_self_attention, 156 | norm_elementwise_affine=norm_elementwise_affine, 157 | allow_xformers=allow_xformers, 158 | only_cross_attention=only_cross_attention, 159 | cross_attn_temporal_cond=self.need_ipadapter, 160 | image_scale=image_scale, 161 | processor=processor, 162 | ) 163 | for d in range(num_layers) 164 | ] 165 | ) 166 | 167 | self.proj_out = nn.Linear(inner_dim, in_channels) 168 | 169 | self.need_temporal_weight = need_temporal_weight 170 | if need_temporal_weight: 171 | self.temporal_weight = nn.Parameter( 172 | torch.tensor( 173 | [ 174 | 1e-5, 175 | ] 176 | ) 177 | ) # initialize parameter with 0 178 | self.skip_temporal_layers = False # Whether to skip temporal layer 179 | self.keep_content_condition = keep_content_condition 180 | self.self_attn_mask = self_attn_mask 181 | self.only_cross_attention = only_cross_attention 182 | self.double_self_attention = double_self_attention 183 | self.cross_attention_dim = cross_attention_dim 184 | self.image_scale = image_scale 185 | # zero out the last layer params,so the conv block is identity 186 | nn.init.zeros_(self.proj_out.weight) 187 | nn.init.zeros_(self.proj_out.bias) 188 | 189 | def forward( 190 | self, 191 | hidden_states, 192 | femb, 193 | encoder_hidden_states=None, 194 | timestep=None, 195 | class_labels=None, 196 | num_frames=1, 197 | cross_attention_kwargs=None, 198 | sample_index: torch.LongTensor = None, 199 | vision_conditon_frames_sample_index: torch.LongTensor = None, 200 | spatial_position_emb: torch.Tensor = None, 201 | return_dict: bool = True, 202 | ): 203 | """ 204 | Args: 205 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 206 | When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 207 | hidden_states 208 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 209 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 210 | self-attention. 211 | timestep ( `torch.long`, *optional*): 212 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 213 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 214 | Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels 215 | conditioning. 216 | return_dict (`bool`, *optional*, defaults to `True`): 217 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
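Example:
    A minimal, shape-only sketch of the temporal layout this module relies on; the sizes below are illustrative assumptions rather than values taken from any MuseV config.

        import torch
        from einops import rearrange

        b, t, c, h, w = 2, 12, 320, 4, 4
        hidden_states = torch.rand(b * t, c, h, w)  # frames folded into the batch axis
        x = rearrange(hidden_states, "(b t) c h w -> b c t h w", b=b)
        tokens = rearrange(x, "b c t h w -> (b h w) t c")  # each spatial location attends over time
        assert tokens.shape == (b * h * w, t, c)
        out = rearrange(tokens, "(b h w) t c -> (b t) c h w", b=b, h=h, w=w)
        assert out.shape == hidden_states.shape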
218 | 219 | Returns: 220 | [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: 221 | [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. 222 | When returning a tuple, the first element is the sample tensor. 223 | """ 224 | if self.skip_temporal_layers is True: 225 | if not return_dict: 226 | return (hidden_states,) 227 | 228 | return TransformerTemporalModelOutput(sample=hidden_states) 229 | 230 | # 1. Input 231 | batch_frames, channel, height, width = hidden_states.shape 232 | batch_size = batch_frames // num_frames 233 | 234 | hidden_states = rearrange( 235 | hidden_states, "(b t) c h w -> b c t h w", b=batch_size 236 | ) 237 | residual = hidden_states 238 | 239 | hidden_states = self.norm(hidden_states) 240 | 241 | hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c") 242 | 243 | hidden_states = self.proj_in(hidden_states) 244 | 245 | # 2 Positional embedding 246 | # adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574 247 | if not self.remove_femb_non_linear: 248 | femb = self.nonlinearity(femb) 249 | femb = self.frame_emb_proj(femb) 250 | femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0) 251 | hidden_states = hidden_states + femb 252 | 253 | # 3. Blocks 254 | if ( 255 | (self.only_cross_attention or not self.double_self_attention) 256 | and self.cross_attention_dim is not None 257 | and encoder_hidden_states is not None 258 | ): 259 | encoder_hidden_states = align_repeat_tensor_single_dim( 260 | encoder_hidden_states, 261 | hidden_states.shape[0], 262 | dim=0, 263 | n_src_base_length=batch_size, 264 | ) 265 | 266 | for i, block in enumerate(self.transformer_blocks): 267 | hidden_states = block( 268 | hidden_states, 269 | encoder_hidden_states=encoder_hidden_states, 270 | timestep=timestep, 271 | cross_attention_kwargs=cross_attention_kwargs, 272 | class_labels=class_labels, 273 | ) 274 | 275 | # 4. 
Output 276 | hidden_states = self.proj_out(hidden_states) 277 | hidden_states = rearrange( 278 | hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width 279 | ).contiguous() 280 | 281 | # 保留condition对应的frames,便于保持前序内容帧,提升一致性 282 | # keep the frames corresponding to the condition to maintain the previous content frames and improve consistency 283 | if ( 284 | vision_conditon_frames_sample_index is not None 285 | and self.keep_content_condition 286 | ): 287 | mask = torch.ones_like(hidden_states, device=hidden_states.device) 288 | mask = batch_index_fill( 289 | mask, dim=2, index=vision_conditon_frames_sample_index, value=0 290 | ) 291 | if self.need_temporal_weight: 292 | output = ( 293 | residual + torch.abs(self.temporal_weight) * mask * hidden_states 294 | ) 295 | else: 296 | output = residual + mask * hidden_states 297 | else: 298 | if self.need_temporal_weight: 299 | output = residual + torch.abs(self.temporal_weight) * hidden_states 300 | else: 301 | output = residual + mask * hidden_states 302 | 303 | # output = torch.abs(self.temporal_weight) * hidden_states + residual 304 | output = rearrange(output, "b c t h w -> (b t) c h w") 305 | if not return_dict: 306 | return (output,) 307 | 308 | return TransformerTemporalModelOutput(sample=output) 309 | -------------------------------------------------------------------------------- /musev/models/text_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from torch import nn 3 | 4 | 5 | class TextEmbExtractor(nn.Module): 6 | def __init__(self, tokenizer, text_encoder) -> None: 7 | super(TextEmbExtractor, self).__init__() 8 | self.tokenizer = tokenizer 9 | self.text_encoder = text_encoder 10 | 11 | def forward( 12 | self, 13 | texts, 14 | text_params: Dict = None, 15 | ): 16 | if text_params is None: 17 | text_params = {} 18 | special_prompt_input = self.tokenizer( 19 | texts, 20 | max_length=self.tokenizer.model_max_length, 21 | padding="max_length", 22 | truncation=True, 23 | return_tensors="pt", 24 | ) 25 | if ( 26 | hasattr(self.text_encoder.config, "use_attention_mask") 27 | and self.text_encoder.config.use_attention_mask 28 | ): 29 | attention_mask = special_prompt_input.attention_mask.to( 30 | self.text_encoder.device 31 | ) 32 | else: 33 | attention_mask = None 34 | 35 | embeddings = self.text_encoder( 36 | special_prompt_input.input_ids.to(self.text_encoder.device), 37 | attention_mask=attention_mask, 38 | **text_params 39 | ) 40 | return embeddings 41 | -------------------------------------------------------------------------------- /musev/models/unet_loader.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Iterable, Union 3 | import PIL 4 | import cv2 5 | import torch 6 | import argparse 7 | import datetime 8 | import logging 9 | import inspect 10 | import math 11 | import os 12 | import shutil 13 | from typing import Dict, List, Optional, Tuple 14 | from pprint import pprint 15 | from collections import OrderedDict 16 | from dataclasses import dataclass 17 | import gc 18 | import time 19 | 20 | import numpy as np 21 | from omegaconf import OmegaConf 22 | from omegaconf import SCMode 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | from einops import rearrange, repeat 28 | import pandas as pd 29 | import h5py 30 | from diffusers.models.modeling_utils import load_state_dict 
31 | from diffusers.utils import ( 32 | logging, 33 | ) 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | from ..models.unet_3d_condition import UNet3DConditionModel 37 | 38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 39 | 40 | 41 | def update_unet_with_sd( 42 | unet: nn.Module, sd_model: Tuple[str, nn.Module], subfolder: str = "unet" 43 | ): 44 | """更新T2V模型中的T2I参数. update t2i parameters in t2v model 45 | 46 | Args: 47 | unet (nn.Module): _description_ 48 | sd_model (Tuple[str, nn.Module]): _description_ 49 | 50 | Returns: 51 | _type_: _description_ 52 | """ 53 | # dtype = unet.dtype 54 | # TODO: in this way, sd_model_path must be absolute path, to be more dynamic 55 | if isinstance(sd_model, str): 56 | if os.path.isdir(sd_model): 57 | unet_state_dict = load_state_dict( 58 | os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"), 59 | ) 60 | elif os.path.isfile(sd_model): 61 | if sd_model.endswith("pth"): 62 | unet_state_dict = torch.load(sd_model, map_location="cpu") 63 | print(f"referencenet successful load ={sd_model} with torch.load") 64 | else: 65 | try: 66 | unet_state_dict = load_state_dict(sd_model) 67 | print( 68 | f"referencenet successful load with {sd_model} with load_state_dict" 69 | ) 70 | except Exception as e: 71 | print(e) 72 | 73 | elif isinstance(sd_model, nn.Module): 74 | unet_state_dict = sd_model.state_dict() 75 | else: 76 | raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str") 77 | missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) 78 | assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}" 79 | # unet.to(dtype=dtype) 80 | return unet 81 | 82 | 83 | def load_unet( 84 | sd_unet_model: Tuple[str, nn.Module], 85 | sd_model: Tuple[str, nn.Module] = None, 86 | cross_attention_dim: int = 768, 87 | temporal_transformer: str = "TransformerTemporalModel", 88 | temporal_conv_block: str = "TemporalConvLayer", 89 | need_spatial_position_emb: bool = False, 90 | need_transformer_in: bool = True, 91 | need_t2i_ip_adapter: bool = False, 92 | need_adain_temporal_cond: bool = False, 93 | t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor", 94 | keep_vision_condtion: bool = False, 95 | use_anivv1_cfg: bool = False, 96 | resnet_2d_skip_time_act: bool = False, 97 | dtype: torch.dtype = torch.float16, 98 | need_zero_vis_cond_temb: bool = True, 99 | norm_spatial_length: bool = True, 100 | spatial_max_length: int = 2048, 101 | need_refer_emb: bool = False, 102 | ip_adapter_cross_attn=False, 103 | t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor", 104 | need_t2i_facein: bool = False, 105 | need_t2i_ip_adapter_face: bool = False, 106 | strict: bool = True, 107 | ): 108 | """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. 109 | 该部分都是通过 models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型 110 | model is defined and trained in models.unet_3d_condition.py:UNet3DConditionModel 111 | 112 | Args: 113 | sd_unet_model (Tuple[str, nn.Module]): _description_ 114 | sd_model (Tuple[str, nn.Module]): _description_ 115 | cross_attention_dim (int, optional): _description_. Defaults to 768. 116 | temporal_transformer (str, optional): _description_. Defaults to "TransformerTemporalModel". 117 | temporal_conv_block (str, optional): _description_. Defaults to "TemporalConvLayer". 118 | need_spatial_position_emb (bool, optional): _description_. Defaults to False. 
119 | need_transformer_in (bool, optional): _description_. Defaults to True. 120 | need_t2i_ip_adapter (bool, optional): _description_. Defaults to False. 121 | need_adain_temporal_cond (bool, optional): _description_. Defaults to False. 122 | t2i_ip_adapter_attn_processor (str, optional): _description_. Defaults to "IPXFormersAttnProcessor". 123 | keep_vision_condtion (bool, optional): _description_. Defaults to False. 124 | use_anivv1_cfg (bool, optional): _description_. Defaults to False. 125 | resnet_2d_skip_time_act (bool, optional): _description_. Defaults to False. 126 | dtype (torch.dtype, optional): _description_. Defaults to torch.float16. 127 | need_zero_vis_cond_temb (bool, optional): _description_. Defaults to True. 128 | norm_spatial_length (bool, optional): _description_. Defaults to True. 129 | spatial_max_length (int, optional): _description_. Defaults to 2048. 130 | 131 | Returns: 132 | _type_: _description_ 133 | """ 134 | if isinstance(sd_unet_model, str): 135 | unet = UNet3DConditionModel.from_pretrained_2d( 136 | sd_unet_model, 137 | subfolder="unet", 138 | temporal_transformer=temporal_transformer, 139 | temporal_conv_block=temporal_conv_block, 140 | cross_attention_dim=cross_attention_dim, 141 | need_spatial_position_emb=need_spatial_position_emb, 142 | need_transformer_in=need_transformer_in, 143 | need_t2i_ip_adapter=need_t2i_ip_adapter, 144 | need_adain_temporal_cond=need_adain_temporal_cond, 145 | t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor, 146 | keep_vision_condtion=keep_vision_condtion, 147 | use_anivv1_cfg=use_anivv1_cfg, 148 | resnet_2d_skip_time_act=resnet_2d_skip_time_act, 149 | torch_dtype=dtype, 150 | need_zero_vis_cond_temb=need_zero_vis_cond_temb, 151 | norm_spatial_length=norm_spatial_length, 152 | spatial_max_length=spatial_max_length, 153 | need_refer_emb=need_refer_emb, 154 | ip_adapter_cross_attn=ip_adapter_cross_attn, 155 | t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor, 156 | need_t2i_facein=need_t2i_facein, 157 | strict=strict, 158 | need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, 159 | ) 160 | elif isinstance(sd_unet_model, nn.Module): 161 | unet = sd_unet_model 162 | if sd_model is not None: 163 | unet = update_unet_with_sd(unet, sd_model) 164 | return unet 165 | 166 | 167 | def load_unet_custom_unet( 168 | sd_unet_model: Tuple[str, nn.Module], 169 | sd_model: Tuple[str, nn.Module], 170 | unet_class: nn.Module, 171 | ): 172 | """ 173 | 通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. 
174 | 该部分都是通过 不通过models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型 175 | model is not defined in models.unet_3d_condition.py:UNet3DConditionModel 176 | Args: 177 | sd_unet_model (Tuple[str, nn.Module]): _description_ 178 | sd_model (Tuple[str, nn.Module]): _description_ 179 | unet_class (nn.Module): _description_ 180 | 181 | Returns: 182 | _type_: _description_ 183 | """ 184 | if isinstance(sd_unet_model, str): 185 | unet = unet_class.from_pretrained( 186 | sd_unet_model, 187 | subfolder="unet", 188 | ) 189 | elif isinstance(sd_unet_model, nn.Module): 190 | unet = sd_unet_model 191 | 192 | # TODO: in this way, sd_model_path must be absolute path, to be more dynamic 193 | if isinstance(sd_model, str): 194 | unet_state_dict = load_state_dict( 195 | os.path.join(sd_model, "unet/diffusion_pytorch_model.bin"), 196 | ) 197 | elif isinstance(sd_model, nn.Module): 198 | unet_state_dict = sd_model.state_dict() 199 | missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) 200 | assert ( 201 | len(unexpected) == 0 202 | ), "unet load_state_dict error" # Load scheduler, tokenizer and models. 203 | return unet 204 | 205 | 206 | def load_unet_by_name( 207 | model_name: str, 208 | sd_unet_model: Tuple[str, nn.Module], 209 | sd_model: Tuple[str, nn.Module] = None, 210 | cross_attention_dim: int = 768, 211 | dtype: torch.dtype = torch.float16, 212 | need_t2i_facein: bool = False, 213 | need_t2i_ip_adapter_face: bool = False, 214 | strict: bool = True, 215 | ) -> nn.Module: 216 | """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. 217 | 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义 218 | if you want to use pretrained model with simple name, you need to define it here. 219 | Args: 220 | model_name (str): _description_ 221 | sd_unet_model (Tuple[str, nn.Module]): _description_ 222 | sd_model (Tuple[str, nn.Module]): _description_ 223 | cross_attention_dim (int, optional): _description_. Defaults to 768. 224 | dtype (torch.dtype, optional): _description_. Defaults to torch.float16. 
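Example:
    An illustrative sketch only; the checkpoint paths are placeholders for locally downloaded weights and are not shipped with this repository.

        import torch
        from musev.models.unet_loader import load_unet_by_name

        unet = load_unet_by_name(
            model_name="musev_referencenet",  # also accepts "musev" and "musev_referencenet_pose"
            sd_unet_model="/path/to/musev_referencenet",  # directory containing an "unet" subfolder
            sd_model=None,  # optional T2I checkpoint used to overwrite the T2I weights
            cross_attention_dim=768,
            dtype=torch.float16,
        )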
225 | 226 | Raises: 227 | ValueError: _description_ 228 | 229 | Returns: 230 | nn.Module: _description_ 231 | """ 232 | if model_name in ["musev"]: 233 | unet = load_unet( 234 | sd_unet_model=sd_unet_model, 235 | sd_model=sd_model, 236 | need_spatial_position_emb=False, 237 | cross_attention_dim=cross_attention_dim, 238 | need_t2i_ip_adapter=True, 239 | need_adain_temporal_cond=True, 240 | t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor", 241 | dtype=dtype, 242 | ) 243 | elif model_name in [ 244 | "musev_referencenet", 245 | "musev_referencenet_pose", 246 | ]: 247 | unet = load_unet( 248 | sd_unet_model=sd_unet_model, 249 | sd_model=sd_model, 250 | cross_attention_dim=cross_attention_dim, 251 | temporal_conv_block="TemporalConvLayer", 252 | need_transformer_in=False, 253 | temporal_transformer="TransformerTemporalModel", 254 | use_anivv1_cfg=True, 255 | resnet_2d_skip_time_act=True, 256 | need_t2i_ip_adapter=True, 257 | need_adain_temporal_cond=True, 258 | keep_vision_condtion=True, 259 | t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor", 260 | dtype=dtype, 261 | need_refer_emb=True, 262 | need_zero_vis_cond_temb=True, 263 | ip_adapter_cross_attn=True, 264 | t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor", 265 | need_t2i_facein=need_t2i_facein, 266 | strict=strict, 267 | need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, 268 | ) 269 | else: 270 | raise ValueError( 271 | f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose" 272 | ) 273 | return unet 274 | -------------------------------------------------------------------------------- /musev/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/musev/pipelines/__init__.py -------------------------------------------------------------------------------- /musev/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | import math 3 | from typing import Callable, List, Optional 4 | 5 | import numpy as np 6 | 7 | from mmcm.utils.itertools_util import generate_sample_idxs 8 | 9 | # copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py 10 | 11 | 12 | def ordered_halving(val): 13 | bin_str = f"{val:064b}" 14 | bin_flip = bin_str[::-1] 15 | as_int = int(bin_flip, 2) 16 | 17 | return as_int / (1 << 64) 18 | 19 | 20 | # TODO: closed_loop not work, to fix it 21 | def uniform( 22 | step: int = ..., 23 | num_steps: Optional[int] = None, 24 | num_frames: int = ..., 25 | context_size: Optional[int] = None, 26 | context_stride: int = 3, 27 | context_overlap: int = 4, 28 | closed_loop: bool = True, 29 | ): 30 | if num_frames <= context_size: 31 | yield list(range(num_frames)) 32 | return 33 | 34 | context_stride = min( 35 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 36 | ) 37 | 38 | for context_step in 1 << np.arange(context_stride): 39 | pad = int(round(num_frames * ordered_halving(step))) 40 | for j in range( 41 | int(ordered_halving(step) * context_step) + pad, 42 | num_frames + pad + (0 if closed_loop else -context_overlap), 43 | (context_size * context_step - context_overlap), 44 | ): 45 | yield [ 46 | e % num_frames 47 | for e in range(j, j + context_size * context_step, context_step) 48 | ] 49 | 50 | 51 | def uniform_v2( 52 | step: 
int = ..., 53 | num_steps: Optional[int] = None, 54 | num_frames: int = ..., 55 | context_size: Optional[int] = None, 56 | context_stride: int = 3, 57 | context_overlap: int = 4, 58 | closed_loop: bool = True, 59 | ): 60 | return generate_sample_idxs( 61 | total=num_frames, 62 | window_size=context_size, 63 | step=context_size - context_overlap, 64 | sample_rate=1, 65 | drop_last=False, 66 | ) 67 | 68 | 69 | def get_context_scheduler(name: str) -> Callable: 70 | if name == "uniform": 71 | return uniform 72 | elif name == "uniform_v2": 73 | return uniform_v2 74 | else: 75 | raise ValueError(f"Unknown context_overlap policy {name}") 76 | 77 | 78 | def get_total_steps( 79 | scheduler, 80 | timesteps: List[int], 81 | num_steps: Optional[int] = None, 82 | num_frames: int = ..., 83 | context_size: Optional[int] = None, 84 | context_stride: int = 3, 85 | context_overlap: int = 4, 86 | closed_loop: bool = True, 87 | ): 88 | return sum( 89 | len( 90 | list( 91 | scheduler( 92 | i, 93 | num_steps, 94 | num_frames, 95 | context_size, 96 | context_stride, 97 | context_overlap, 98 | ) 99 | ) 100 | ) 101 | for i in range(len(timesteps)) 102 | ) 103 | 104 | 105 | def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]: 106 | """if len(contexts)>=2 and the max value the oenultimate list same as of the last list 107 | 108 | Args: 109 | List (_type_): _description_ 110 | 111 | Returns: 112 | List[List[int]]: _description_ 113 | """ 114 | if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]: 115 | return contexts[:-1] 116 | else: 117 | return contexts 118 | 119 | 120 | def prepare_global_context( 121 | context_schedule: str, 122 | num_inference_steps: int, 123 | time_size: int, 124 | context_frames: int, 125 | context_stride: int, 126 | context_overlap: int, 127 | context_batch_size: int, 128 | ): 129 | context_scheduler = get_context_scheduler(context_schedule) 130 | context_queue = list( 131 | context_scheduler( 132 | step=0, 133 | num_steps=num_inference_steps, 134 | num_frames=time_size, 135 | context_size=context_frames, 136 | context_stride=context_stride, 137 | context_overlap=context_overlap, 138 | ) 139 | ) 140 | # 如果context_queue的最后一个索引最大值和倒数第二个索引最大值相同,说明最后一个列表就是因为step带来的冗余项,可以去掉 141 | # remove the last context if max index of the last context is the same as the max index of the second last context 142 | context_queue = drop_last_repeat_context(context_queue) 143 | num_context_batches = math.ceil(len(context_queue) / context_batch_size) 144 | global_context = [] 145 | for i_tmp in range(num_context_batches): 146 | global_context.append( 147 | context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size] 148 | ) 149 | return global_context 150 | -------------------------------------------------------------------------------- /musev/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler 2 | from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler 3 | from .scheduling_euler_discrete import EulerDiscreteScheduler 4 | from .scheduling_lcm import LCMScheduler 5 | from .scheduling_ddim import DDIMScheduler 6 | from .scheduling_ddpm import DDPMScheduler 7 | -------------------------------------------------------------------------------- /musev/schedulers/scheduling_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 UC Berkeley Team and The HuggingFace Team. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | from __future__ import annotations 18 | 19 | import math 20 | from dataclasses import dataclass 21 | from typing import List, Optional, Tuple, Union 22 | 23 | import numpy as np 24 | from numpy import ndarray 25 | import torch 26 | 27 | from diffusers.configuration_utils import ConfigMixin, register_to_config 28 | from diffusers.utils import BaseOutput 29 | from diffusers.utils.torch_utils import randn_tensor 30 | from diffusers.schedulers.scheduling_utils import ( 31 | KarrasDiffusionSchedulers, 32 | SchedulerMixin, 33 | ) 34 | from diffusers.schedulers.scheduling_ddpm import ( 35 | DDPMSchedulerOutput, 36 | betas_for_alpha_bar, 37 | DDPMScheduler as DiffusersDDPMScheduler, 38 | ) 39 | from ..utils.noise_util import video_fusion_noise 40 | 41 | 42 | class DDPMScheduler(DiffusersDDPMScheduler): 43 | """ 44 | `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. 45 | 46 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 47 | methods the library implements for all schedulers such as loading and saving. 48 | 49 | Args: 50 | num_train_timesteps (`int`, defaults to 1000): 51 | The number of diffusion steps to train the model. 52 | beta_start (`float`, defaults to 0.0001): 53 | The starting `beta` value of inference. 54 | beta_end (`float`, defaults to 0.02): 55 | The final `beta` value. 56 | beta_schedule (`str`, defaults to `"linear"`): 57 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 58 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 59 | variance_type (`str`, defaults to `"fixed_small"`): 60 | Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, 61 | `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. 62 | clip_sample (`bool`, defaults to `True`): 63 | Clip the predicted sample for numerical stability. 64 | clip_sample_range (`float`, defaults to 1.0): 65 | The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. 66 | prediction_type (`str`, defaults to `epsilon`, *optional*): 67 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), 68 | `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen 69 | Video](https://imagen.research.google/video/paper.pdf) paper). 70 | thresholding (`bool`, defaults to `False`): 71 | Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such 72 | as Stable Diffusion. 73 | dynamic_thresholding_ratio (`float`, defaults to 0.995): 74 | The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 
75 | sample_max_value (`float`, defaults to 1.0): 76 | The threshold value for dynamic thresholding. Valid only when `thresholding=True`. 77 | timestep_spacing (`str`, defaults to `"leading"`): 78 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 79 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 80 | steps_offset (`int`, defaults to 0): 81 | An offset added to the inference steps. You can use a combination of `offset=1` and 82 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable 83 | Diffusion. 84 | """ 85 | 86 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 87 | order = 1 88 | 89 | @register_to_config 90 | def __init__( 91 | self, 92 | num_train_timesteps: int = 1000, 93 | beta_start: float = 0.0001, 94 | beta_end: float = 0.02, 95 | beta_schedule: str = "linear", 96 | trained_betas: ndarray | List[float] | None = None, 97 | variance_type: str = "fixed_small", 98 | clip_sample: bool = True, 99 | prediction_type: str = "epsilon", 100 | thresholding: bool = False, 101 | dynamic_thresholding_ratio: float = 0.995, 102 | clip_sample_range: float = 1, 103 | sample_max_value: float = 1, 104 | timestep_spacing: str = "leading", 105 | steps_offset: int = 0, 106 | ): 107 | super().__init__( 108 | num_train_timesteps, 109 | beta_start, 110 | beta_end, 111 | beta_schedule, 112 | trained_betas, 113 | variance_type, 114 | clip_sample, 115 | prediction_type, 116 | thresholding, 117 | dynamic_thresholding_ratio, 118 | clip_sample_range, 119 | sample_max_value, 120 | timestep_spacing, 121 | steps_offset, 122 | ) 123 | 124 | def step( 125 | self, 126 | model_output: torch.FloatTensor, 127 | timestep: int, 128 | sample: torch.FloatTensor, 129 | generator=None, 130 | return_dict: bool = True, 131 | w_ind_noise: float = 0.5, 132 | noise_type: str = "random", 133 | ) -> Union[DDPMSchedulerOutput, Tuple]: 134 | """ 135 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 136 | process from the learned model outputs (most often the predicted noise). 137 | 138 | Args: 139 | model_output (`torch.FloatTensor`): 140 | The direct output from learned diffusion model. 141 | timestep (`float`): 142 | The current discrete timestep in the diffusion chain. 143 | sample (`torch.FloatTensor`): 144 | A current instance of a sample created by the diffusion process. 145 | generator (`torch.Generator`, *optional*): 146 | A random number generator. 147 | return_dict (`bool`, *optional*, defaults to `True`): 148 | Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. 149 | 150 | Returns: 151 | [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: 152 | If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a 153 | tuple is returned where the first element is the sample tensor. 154 | 155 | """ 156 | t = timestep 157 | 158 | prev_t = self.previous_timestep(t) 159 | 160 | if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ 161 | "learned", 162 | "learned_range", 163 | ]: 164 | model_output, predicted_variance = torch.split( 165 | model_output, sample.shape[1], dim=1 166 | ) 167 | else: 168 | predicted_variance = None 169 | 170 | # 1. 
compute alphas, betas 171 | alpha_prod_t = self.alphas_cumprod[t] 172 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one 173 | beta_prod_t = 1 - alpha_prod_t 174 | beta_prod_t_prev = 1 - alpha_prod_t_prev 175 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 176 | current_beta_t = 1 - current_alpha_t 177 | 178 | # 2. compute predicted original sample from predicted noise also called 179 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 180 | if self.config.prediction_type == "epsilon": 181 | pred_original_sample = ( 182 | sample - beta_prod_t ** (0.5) * model_output 183 | ) / alpha_prod_t ** (0.5) 184 | elif self.config.prediction_type == "sample": 185 | pred_original_sample = model_output 186 | elif self.config.prediction_type == "v_prediction": 187 | pred_original_sample = (alpha_prod_t**0.5) * sample - ( 188 | beta_prod_t**0.5 189 | ) * model_output 190 | else: 191 | raise ValueError( 192 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 193 | " `v_prediction` for the DDPMScheduler." 194 | ) 195 | 196 | # 3. Clip or threshold "predicted x_0" 197 | if self.config.thresholding: 198 | pred_original_sample = self._threshold_sample(pred_original_sample) 199 | elif self.config.clip_sample: 200 | pred_original_sample = pred_original_sample.clamp( 201 | -self.config.clip_sample_range, self.config.clip_sample_range 202 | ) 203 | 204 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 205 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 206 | pred_original_sample_coeff = ( 207 | alpha_prod_t_prev ** (0.5) * current_beta_t 208 | ) / beta_prod_t 209 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 210 | 211 | # 5. Compute predicted previous sample µ_t 212 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 213 | pred_prev_sample = ( 214 | pred_original_sample_coeff * pred_original_sample 215 | + current_sample_coeff * sample 216 | ) 217 | 218 | # 6. 
Add noise 219 | variance = 0 220 | if t > 0: 221 | device = model_output.device 222 | # if variance_noise is None: 223 | # variance_noise = randn_tensor( 224 | # model_output.shape, 225 | # generator=generator, 226 | # device=model_output.device, 227 | # dtype=model_output.dtype, 228 | # ) 229 | device = model_output.device 230 | 231 | if noise_type == "random": 232 | variance_noise = randn_tensor( 233 | model_output.shape, 234 | dtype=model_output.dtype, 235 | device=device, 236 | generator=generator, 237 | ) 238 | elif noise_type == "video_fusion": 239 | variance_noise = video_fusion_noise( 240 | model_output, w_ind_noise=w_ind_noise, generator=generator 241 | ) 242 | if self.variance_type == "fixed_small_log": 243 | variance = ( 244 | self._get_variance(t, predicted_variance=predicted_variance) 245 | * variance_noise 246 | ) 247 | elif self.variance_type == "learned_range": 248 | variance = self._get_variance(t, predicted_variance=predicted_variance) 249 | variance = torch.exp(0.5 * variance) * variance_noise 250 | else: 251 | variance = ( 252 | self._get_variance(t, predicted_variance=predicted_variance) ** 0.5 253 | ) * variance_noise 254 | 255 | pred_prev_sample = pred_prev_sample + variance 256 | 257 | if not return_dict: 258 | return (pred_prev_sample,) 259 | 260 | return DDPMSchedulerOutput( 261 | prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample 262 | ) 263 | -------------------------------------------------------------------------------- /musev/schedulers/scheduling_euler_discrete.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | 4 | from typing import List, Optional, Tuple, Union 5 | import numpy as np 6 | from numpy import ndarray 7 | import torch 8 | from torch import Generator, FloatTensor 9 | from diffusers.schedulers.scheduling_euler_discrete import ( 10 | EulerDiscreteScheduler as DiffusersEulerDiscreteScheduler, 11 | EulerDiscreteSchedulerOutput, 12 | ) 13 | from diffusers.utils.torch_utils import randn_tensor 14 | 15 | from ..utils.noise_util import video_fusion_noise 16 | 17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 18 | 19 | 20 | class EulerDiscreteScheduler(DiffusersEulerDiscreteScheduler): 21 | def __init__( 22 | self, 23 | num_train_timesteps: int = 1000, 24 | beta_start: float = 0.0001, 25 | beta_end: float = 0.02, 26 | beta_schedule: str = "linear", 27 | trained_betas: ndarray | List[float] | None = None, 28 | prediction_type: str = "epsilon", 29 | interpolation_type: str = "linear", 30 | use_karras_sigmas: bool | None = False, 31 | timestep_spacing: str = "linspace", 32 | steps_offset: int = 0, 33 | ): 34 | super().__init__( 35 | num_train_timesteps, 36 | beta_start, 37 | beta_end, 38 | beta_schedule, 39 | trained_betas, 40 | prediction_type, 41 | interpolation_type, 42 | use_karras_sigmas, 43 | timestep_spacing, 44 | steps_offset, 45 | ) 46 | 47 | def step( 48 | self, 49 | model_output: torch.FloatTensor, 50 | timestep: Union[float, torch.FloatTensor], 51 | sample: torch.FloatTensor, 52 | s_churn: float = 0.0, 53 | s_tmin: float = 0.0, 54 | s_tmax: float = float("inf"), 55 | s_noise: float = 1.0, 56 | generator: Optional[torch.Generator] = None, 57 | return_dict: bool = True, 58 | w_ind_noise: float = 0.5, 59 | noise_type: str = "random", 60 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 61 | """ 62 | Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion 63 | process from the learned model outputs (most often the predicted noise). 64 | 65 | Args: 66 | model_output (`torch.FloatTensor`): 67 | The direct output from learned diffusion model. 68 | timestep (`float`): 69 | The current discrete timestep in the diffusion chain. 70 | sample (`torch.FloatTensor`): 71 | A current instance of a sample created by the diffusion process. 72 | s_churn (`float`): 73 | s_tmin (`float`): 74 | s_tmax (`float`): 75 | s_noise (`float`, defaults to 1.0): 76 | Scaling factor for noise added to the sample. 77 | generator (`torch.Generator`, *optional*): 78 | A random number generator. 79 | return_dict (`bool`): 80 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 81 | tuple. 82 | 83 | Returns: 84 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 85 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 86 | returned, otherwise a tuple is returned where the first element is the sample tensor. 87 | """ 88 | 89 | if ( 90 | isinstance(timestep, int) 91 | or isinstance(timestep, torch.IntTensor) 92 | or isinstance(timestep, torch.LongTensor) 93 | ): 94 | raise ValueError( 95 | ( 96 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 97 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 98 | " one of the `scheduler.timesteps` as a timestep." 99 | ), 100 | ) 101 | 102 | if not self.is_scale_input_called: 103 | logger.warning( 104 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 105 | "See `StableDiffusionPipeline` for a usage example." 106 | ) 107 | 108 | if self.step_index is None: 109 | self._init_step_index(timestep) 110 | 111 | sigma = self.sigmas[self.step_index] 112 | 113 | gamma = ( 114 | min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) 115 | if s_tmin <= sigma <= s_tmax 116 | else 0.0 117 | ) 118 | device = model_output.device 119 | 120 | if noise_type == "random": 121 | noise = randn_tensor( 122 | model_output.shape, 123 | dtype=model_output.dtype, 124 | device=device, 125 | generator=generator, 126 | ) 127 | elif noise_type == "video_fusion": 128 | noise = video_fusion_noise( 129 | model_output, w_ind_noise=w_ind_noise, generator=generator 130 | ) 131 | 132 | eps = noise * s_noise 133 | sigma_hat = sigma * (gamma + 1) 134 | 135 | if gamma > 0: 136 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 137 | 138 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 139 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 140 | # backwards compatibility 141 | if ( 142 | self.config.prediction_type == "original_sample" 143 | or self.config.prediction_type == "sample" 144 | ): 145 | pred_original_sample = model_output 146 | elif self.config.prediction_type == "epsilon": 147 | pred_original_sample = sample - sigma_hat * model_output 148 | elif self.config.prediction_type == "v_prediction": 149 | # * c_out + input * c_skip 150 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + ( 151 | sample / (sigma**2 + 1) 152 | ) 153 | else: 154 | raise ValueError( 155 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 156 | ) 157 | 158 | # 2. 
Convert to an ODE derivative 159 | derivative = (sample - pred_original_sample) / sigma_hat 160 | 161 | dt = self.sigmas[self.step_index + 1] - sigma_hat 162 | 163 | prev_sample = sample + derivative * dt 164 | 165 | # upon completion increase step index by one 166 | self._step_index += 1 167 | 168 | if not return_dict: 169 | return (prev_sample,) 170 | 171 | return EulerDiscreteSchedulerOutput( 172 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 173 | ) 174 | 175 | def step_bk( 176 | self, 177 | model_output: FloatTensor, 178 | timestep: float | FloatTensor, 179 | sample: FloatTensor, 180 | s_churn: float = 0, 181 | s_tmin: float = 0, 182 | s_tmax: float = float("inf"), 183 | s_noise: float = 1, 184 | generator: Generator | None = None, 185 | return_dict: bool = True, 186 | w_ind_noise: float = 0.5, 187 | noise_type: str = "random", 188 | ) -> EulerDiscreteSchedulerOutput | Tuple: 189 | """ 190 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 191 | process from the learned model outputs (most often the predicted noise). 192 | 193 | Args: 194 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 195 | timestep (`float`): current timestep in the diffusion chain. 196 | sample (`torch.FloatTensor`): 197 | current instance of sample being created by diffusion process. 198 | s_churn (`float`) 199 | s_tmin (`float`) 200 | s_tmax (`float`) 201 | s_noise (`float`) 202 | generator (`torch.Generator`, optional): Random number generator. 203 | return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class 204 | 205 | Returns: 206 | [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: 207 | [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a 208 | `tuple`. When returning a tuple, the first element is the sample tensor. 209 | 210 | """ 211 | 212 | if ( 213 | isinstance(timestep, int) 214 | or isinstance(timestep, torch.IntTensor) 215 | or isinstance(timestep, torch.LongTensor) 216 | ): 217 | raise ValueError( 218 | ( 219 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 220 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 221 | " one of the `scheduler.timesteps` as a timestep." 222 | ), 223 | ) 224 | 225 | if not self.is_scale_input_called: 226 | logger.warning( 227 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 228 | "See `StableDiffusionPipeline` for a usage example." 229 | ) 230 | 231 | if isinstance(timestep, torch.Tensor): 232 | timestep = timestep.to(self.timesteps.device) 233 | 234 | step_index = (self.timesteps == timestep).nonzero().item() 235 | sigma = self.sigmas[step_index] 236 | 237 | gamma = ( 238 | min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) 239 | if s_tmin <= sigma <= s_tmax 240 | else 0.0 241 | ) 242 | 243 | device = model_output.device 244 | if noise_type == "random": 245 | noise = randn_tensor( 246 | model_output.shape, 247 | dtype=model_output.dtype, 248 | device=device, 249 | generator=generator, 250 | ) 251 | elif noise_type == "video_fusion": 252 | noise = video_fusion_noise( 253 | model_output, w_ind_noise=w_ind_noise, generator=generator 254 | ) 255 | eps = noise * s_noise 256 | sigma_hat = sigma * (gamma + 1) 257 | 258 | if gamma > 0: 259 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 260 | 261 | # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 262 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 263 | # backwards compatibility 264 | if ( 265 | self.config.prediction_type == "original_sample" 266 | or self.config.prediction_type == "sample" 267 | ): 268 | pred_original_sample = model_output 269 | elif self.config.prediction_type == "epsilon": 270 | pred_original_sample = sample - sigma_hat * model_output 271 | elif self.config.prediction_type == "v_prediction": 272 | # * c_out + input * c_skip 273 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + ( 274 | sample / (sigma**2 + 1) 275 | ) 276 | else: 277 | raise ValueError( 278 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 279 | ) 280 | 281 | # 2. Convert to an ODE derivative 282 | derivative = (sample - pred_original_sample) / sigma_hat 283 | 284 | dt = self.sigmas[step_index + 1] - sigma_hat 285 | 286 | prev_sample = sample + derivative * dt 287 | 288 | if not return_dict: 289 | return (prev_sample,) 290 | 291 | return EulerDiscreteSchedulerOutput( 292 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 293 | ) 294 | -------------------------------------------------------------------------------- /musev/schedulers/scheduling_lcm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
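# NOTE: this module subclasses diffusers' LCMScheduler so that `step()` can draw the noise
# used between timesteps either independently for every frame ("random") or with a component
# shared across frames ("video_fusion", weighted by `w_ind_noise`), which keeps the noise
# correlated along the temporal axis.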
14 | 15 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion 16 | # and https://github.com/hojonathanho/diffusion 17 | from __future__ import annotations 18 | 19 | import math 20 | from dataclasses import dataclass 21 | from typing import List, Optional, Tuple, Union 22 | 23 | import numpy as np 24 | import torch 25 | from numpy import ndarray 26 | 27 | from diffusers.configuration_utils import ConfigMixin, register_to_config 28 | from diffusers.utils import BaseOutput, logging 29 | from diffusers.utils.torch_utils import randn_tensor 30 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 31 | from diffusers.schedulers.scheduling_lcm import ( 32 | LCMSchedulerOutput, 33 | betas_for_alpha_bar, 34 | rescale_zero_terminal_snr, 35 | LCMScheduler as DiffusersLCMScheduler, 36 | ) 37 | from ..utils.noise_util import video_fusion_noise 38 | 39 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 40 | 41 | 42 | class LCMScheduler(DiffusersLCMScheduler): 43 | def __init__( 44 | self, 45 | num_train_timesteps: int = 1000, 46 | beta_start: float = 0.00085, 47 | beta_end: float = 0.012, 48 | beta_schedule: str = "scaled_linear", 49 | trained_betas: ndarray | List[float] | None = None, 50 | original_inference_steps: int = 50, 51 | clip_sample: bool = False, 52 | clip_sample_range: float = 1, 53 | set_alpha_to_one: bool = True, 54 | steps_offset: int = 0, 55 | prediction_type: str = "epsilon", 56 | thresholding: bool = False, 57 | dynamic_thresholding_ratio: float = 0.995, 58 | sample_max_value: float = 1, 59 | timestep_spacing: str = "leading", 60 | timestep_scaling: float = 10, 61 | rescale_betas_zero_snr: bool = False, 62 | ): 63 | super().__init__( 64 | num_train_timesteps, 65 | beta_start, 66 | beta_end, 67 | beta_schedule, 68 | trained_betas, 69 | original_inference_steps, 70 | clip_sample, 71 | clip_sample_range, 72 | set_alpha_to_one, 73 | steps_offset, 74 | prediction_type, 75 | thresholding, 76 | dynamic_thresholding_ratio, 77 | sample_max_value, 78 | timestep_spacing, 79 | timestep_scaling, 80 | rescale_betas_zero_snr, 81 | ) 82 | 83 | def step( 84 | self, 85 | model_output: torch.FloatTensor, 86 | timestep: int, 87 | sample: torch.FloatTensor, 88 | generator: Optional[torch.Generator] = None, 89 | return_dict: bool = True, 90 | w_ind_noise: float = 0.5, 91 | noise_type: str = "random", 92 | ) -> Union[LCMSchedulerOutput, Tuple]: 93 | """ 94 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 95 | process from the learned model outputs (most often the predicted noise). 96 | 97 | Args: 98 | model_output (`torch.FloatTensor`): 99 | The direct output from learned diffusion model. 100 | timestep (`float`): 101 | The current discrete timestep in the diffusion chain. 102 | sample (`torch.FloatTensor`): 103 | A current instance of a sample created by the diffusion process. 104 | generator (`torch.Generator`, *optional*): 105 | A random number generator. 106 | return_dict (`bool`, *optional*, defaults to `True`): 107 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. 108 | Returns: 109 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: 110 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a 111 | tuple is returned where the first element is the sample tensor. 
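        Note:
            Compared with the upstream implementation, this override also accepts `w_ind_noise`
            (`float`, defaults to 0.5) and `noise_type` (`str`, `"random"` or `"video_fusion"`),
            which select how the noise injected between timesteps is sampled.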
112 | """ 113 | if self.num_inference_steps is None: 114 | raise ValueError( 115 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 116 | ) 117 | 118 | if self.step_index is None: 119 | self._init_step_index(timestep) 120 | 121 | # 1. get previous step value 122 | prev_step_index = self.step_index + 1 123 | if prev_step_index < len(self.timesteps): 124 | prev_timestep = self.timesteps[prev_step_index] 125 | else: 126 | prev_timestep = timestep 127 | 128 | # 2. compute alphas, betas 129 | alpha_prod_t = self.alphas_cumprod[timestep] 130 | alpha_prod_t_prev = ( 131 | self.alphas_cumprod[prev_timestep] 132 | if prev_timestep >= 0 133 | else self.final_alpha_cumprod 134 | ) 135 | 136 | beta_prod_t = 1 - alpha_prod_t 137 | beta_prod_t_prev = 1 - alpha_prod_t_prev 138 | 139 | # 3. Get scalings for boundary conditions 140 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) 141 | 142 | # 4. Compute the predicted original sample x_0 based on the model parameterization 143 | if self.config.prediction_type == "epsilon": # noise-prediction 144 | predicted_original_sample = ( 145 | sample - beta_prod_t.sqrt() * model_output 146 | ) / alpha_prod_t.sqrt() 147 | elif self.config.prediction_type == "sample": # x-prediction 148 | predicted_original_sample = model_output 149 | elif self.config.prediction_type == "v_prediction": # v-prediction 150 | predicted_original_sample = ( 151 | alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output 152 | ) 153 | else: 154 | raise ValueError( 155 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 156 | " `v_prediction` for `LCMScheduler`." 157 | ) 158 | 159 | # 5. Clip or threshold "predicted x_0" 160 | if self.config.thresholding: 161 | predicted_original_sample = self._threshold_sample( 162 | predicted_original_sample 163 | ) 164 | elif self.config.clip_sample: 165 | predicted_original_sample = predicted_original_sample.clamp( 166 | -self.config.clip_sample_range, self.config.clip_sample_range 167 | ) 168 | 169 | # 6. Denoise model output using boundary conditions 170 | denoised = c_out * predicted_original_sample + c_skip * sample 171 | 172 | # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference 173 | # Noise is not used on the final timestep of the timestep schedule. 174 | # This also means that noise is not used for one-step sampling. 
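        # With noise_type="video_fusion", the noise drawn below shares a common component across
        # the frame axis and mixes in per-frame noise weighted by w_ind_noise (see
        # musev/utils/noise_util.py::video_fusion_noise); the previous sample is then
        #     prev_sample = sqrt(alpha_prod_t_prev) * denoised + sqrt(beta_prod_t_prev) * noise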
175 | device = model_output.device 176 | 177 | if self.step_index != self.num_inference_steps - 1: 178 | if noise_type == "random": 179 | noise = randn_tensor( 180 | model_output.shape, 181 | dtype=model_output.dtype, 182 | device=device, 183 | generator=generator, 184 | ) 185 | elif noise_type == "video_fusion": 186 | noise = video_fusion_noise( 187 | model_output, w_ind_noise=w_ind_noise, generator=generator 188 | ) 189 | prev_sample = ( 190 | alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise 191 | ) 192 | else: 193 | prev_sample = denoised 194 | 195 | # upon completion increase step index by one 196 | self._step_index += 1 197 | 198 | if not return_dict: 199 | return (prev_sample, denoised) 200 | 201 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) 202 | 203 | def step_bk( 204 | self, 205 | model_output: torch.FloatTensor, 206 | timestep: int, 207 | sample: torch.FloatTensor, 208 | generator: Optional[torch.Generator] = None, 209 | return_dict: bool = True, 210 | ) -> Union[LCMSchedulerOutput, Tuple]: 211 | """ 212 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 213 | process from the learned model outputs (most often the predicted noise). 214 | 215 | Args: 216 | model_output (`torch.FloatTensor`): 217 | The direct output from learned diffusion model. 218 | timestep (`float`): 219 | The current discrete timestep in the diffusion chain. 220 | sample (`torch.FloatTensor`): 221 | A current instance of a sample created by the diffusion process. 222 | generator (`torch.Generator`, *optional*): 223 | A random number generator. 224 | return_dict (`bool`, *optional*, defaults to `True`): 225 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. 226 | Returns: 227 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: 228 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a 229 | tuple is returned where the first element is the sample tensor. 230 | """ 231 | if self.num_inference_steps is None: 232 | raise ValueError( 233 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 234 | ) 235 | 236 | if self.step_index is None: 237 | self._init_step_index(timestep) 238 | 239 | # 1. get previous step value 240 | prev_step_index = self.step_index + 1 241 | if prev_step_index < len(self.timesteps): 242 | prev_timestep = self.timesteps[prev_step_index] 243 | else: 244 | prev_timestep = timestep 245 | 246 | # 2. compute alphas, betas 247 | alpha_prod_t = self.alphas_cumprod[timestep] 248 | alpha_prod_t_prev = ( 249 | self.alphas_cumprod[prev_timestep] 250 | if prev_timestep >= 0 251 | else self.final_alpha_cumprod 252 | ) 253 | 254 | beta_prod_t = 1 - alpha_prod_t 255 | beta_prod_t_prev = 1 - alpha_prod_t_prev 256 | 257 | # 3. Get scalings for boundary conditions 258 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) 259 | 260 | # 4. 
Compute the predicted original sample x_0 based on the model parameterization 261 | if self.config.prediction_type == "epsilon": # noise-prediction 262 | predicted_original_sample = ( 263 | sample - beta_prod_t.sqrt() * model_output 264 | ) / alpha_prod_t.sqrt() 265 | elif self.config.prediction_type == "sample": # x-prediction 266 | predicted_original_sample = model_output 267 | elif self.config.prediction_type == "v_prediction": # v-prediction 268 | predicted_original_sample = ( 269 | alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output 270 | ) 271 | else: 272 | raise ValueError( 273 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 274 | " `v_prediction` for `LCMScheduler`." 275 | ) 276 | 277 | # 5. Clip or threshold "predicted x_0" 278 | if self.config.thresholding: 279 | predicted_original_sample = self._threshold_sample( 280 | predicted_original_sample 281 | ) 282 | elif self.config.clip_sample: 283 | predicted_original_sample = predicted_original_sample.clamp( 284 | -self.config.clip_sample_range, self.config.clip_sample_range 285 | ) 286 | 287 | # 6. Denoise model output using boundary conditions 288 | denoised = c_out * predicted_original_sample + c_skip * sample 289 | 290 | # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference 291 | # Noise is not used on the final timestep of the timestep schedule. 292 | # This also means that noise is not used for one-step sampling. 293 | if self.step_index != self.num_inference_steps - 1: 294 | noise = randn_tensor( 295 | model_output.shape, 296 | generator=generator, 297 | device=model_output.device, 298 | dtype=denoised.dtype, 299 | ) 300 | prev_sample = ( 301 | alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise 302 | ) 303 | else: 304 | prev_sample = denoised 305 | 306 | # upon completion increase step index by one 307 | self._step_index += 1 308 | 309 | if not return_dict: 310 | return (prev_sample, denoised) 311 | 312 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) 313 | -------------------------------------------------------------------------------- /musev/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TMElyralab/MuseV/43370a6215afdfcd6d6af404350c132cd3b6eef8/musev/utils/__init__.py -------------------------------------------------------------------------------- /musev/utils/attention_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Literal 2 | 3 | from einops import repeat 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def get_diags_indices( 9 | shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0 10 | ): 11 | if isinstance(shape, int): 12 | shape = (shape, shape) 13 | rows, cols = np.indices(shape) 14 | diag = cols - rows 15 | return np.where((diag >= k_min) & (diag <= k_max)) 16 | 17 | 18 | def generate_mask_from_indices( 19 | shape: Tuple[int, int], 20 | indices: Tuple[np.ndarray, np.ndarray], 21 | big_value: float = 0, 22 | small_value: float = -1e9, 23 | ): 24 | matrix = np.ones(shape) * small_value 25 | matrix[indices] = big_value 26 | return matrix 27 | 28 | 29 | def generate_sparse_causcal_attn_mask( 30 | batch_size: int, 31 | n: int, 32 | n_near: int = 1, 33 | big_value: float = 0, 34 | small_value: float = -1e9, 35 | out_type: Literal["torch", "numpy"] = "numpy", 36 | expand: int = 1, 37 | ) -> np.ndarray: 38 | """generate b (n 
expand) (n expand) mask,
39 |         where entries on the diagonals with 0 <= row - col <= n_near and entries in the first column of the (n, n) matrix are set to big_value, and all other entries to small_value.
40 |     About `expand`:
41 |         when attn is b n d the mask is b n n; when attn is b (expand n) d the mask is b (n expand) (n expand).
42 |     Args:
43 |         batch_size (int): batch size of the returned mask.
44 |         n (int): number of positions along each attention axis.
45 |         n_near (int, optional): how many immediately preceding positions each position may attend to. Defaults to 1.
46 |         big_value (float, optional): value assigned to visible positions. Defaults to 0.
47 |         small_value (float, optional): value assigned to masked positions. Defaults to -1e9.
48 |         out_type (Literal["torch", "numpy"], optional): whether to return a torch tensor or a numpy array. Defaults to "numpy".
49 |         expand (int, optional): expansion factor applied to both mask axes. Defaults to 1.
50 |
51 |     Returns:
52 |         np.ndarray: attention mask of shape (batch_size, n * expand, n * expand).
53 |     """
54 |     shape = (n, n)
55 |     diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0)
56 |     first_column = (np.arange(n), np.zeros(n).astype(int))
57 |     indices = (
58 |         np.concatenate([diag_indices[0], first_column[0]]),
59 |         np.concatenate([diag_indices[1], first_column[1]]),
60 |     )
61 |     mask = generate_mask_from_indices(
62 |         shape=shape, indices=indices, big_value=big_value, small_value=small_value
63 |     )
64 |     mask = repeat(mask, "m n -> b m n", b=batch_size)
65 |     if expand > 1:
66 |         mask = repeat(
67 |             mask,
68 |             "b m n -> b (m d1) (n d2)",
69 |             d1=expand,
70 |             d2=expand,
71 |         )
72 |     if out_type == "torch":
73 |         mask = torch.from_numpy(mask)
74 |     return mask
75 |
--------------------------------------------------------------------------------
/musev/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ Conversion script for the LoRA's safetensors checkpoints. """
17 |
18 | import argparse
19 |
20 | import torch
21 | from safetensors.torch import load_file
22 |
23 | from diffusers import StableDiffusionPipeline
24 | import pdb
25 |
26 |
27 |
28 | def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29 |     # directly update weight in diffusers model
30 |     for key in state_dict:
31 |         # only process lora down key
32 |         if "up."
in key: continue 33 | 34 | up_key = key.replace(".down.", ".up.") 35 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 36 | model_key = model_key.replace("to_out.", "to_out.0.") 37 | layer_infos = model_key.split(".")[:-1] 38 | 39 | curr_layer = pipeline.unet 40 | while len(layer_infos) > 0: 41 | temp_name = layer_infos.pop(0) 42 | curr_layer = curr_layer.__getattr__(temp_name) 43 | 44 | weight_down = state_dict[key] 45 | weight_up = state_dict[up_key] 46 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 47 | 48 | return pipeline 49 | 50 | 51 | 52 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 53 | # load base model 54 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 55 | 56 | # load LoRA weight from .safetensors 57 | # state_dict = load_file(checkpoint_path) 58 | 59 | visited = [] 60 | 61 | # directly update weight in diffusers model 62 | for key in state_dict: 63 | # it is suggested to print out the key, it usually will be something like below 64 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 65 | 66 | # as we have set the alpha beforehand, so just skip 67 | if ".alpha" in key or key in visited: 68 | continue 69 | 70 | if "text" in key: 71 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 72 | curr_layer = pipeline.text_encoder 73 | else: 74 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 75 | curr_layer = pipeline.unet 76 | 77 | # find the target layer 78 | temp_name = layer_infos.pop(0) 79 | while len(layer_infos) > -1: 80 | try: 81 | curr_layer = curr_layer.__getattr__(temp_name) 82 | if len(layer_infos) > 0: 83 | temp_name = layer_infos.pop(0) 84 | elif len(layer_infos) == 0: 85 | break 86 | except Exception: 87 | if len(temp_name) > 0: 88 | temp_name += "_" + layer_infos.pop(0) 89 | else: 90 | temp_name = layer_infos.pop(0) 91 | 92 | pair_keys = [] 93 | if "lora_down" in key: 94 | pair_keys.append(key.replace("lora_down", "lora_up")) 95 | pair_keys.append(key) 96 | else: 97 | pair_keys.append(key) 98 | pair_keys.append(key.replace("lora_up", "lora_down")) 99 | 100 | # update weight 101 | if len(state_dict[pair_keys[0]].shape) == 4: 102 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 103 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 104 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 105 | else: 106 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 107 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 108 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 109 | 110 | # update visited list 111 | for item in pair_keys: 112 | visited.append(item) 113 | 114 | return pipeline 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument( 121 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 122 | ) 123 | parser.add_argument( 124 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
125 |     )
126 |     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127 |     parser.add_argument(
128 |         "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129 |     )
130 |     parser.add_argument(
131 |         "--lora_prefix_text_encoder",
132 |         default="lora_te",
133 |         type=str,
134 |         help="The prefix of text encoder weight in safetensors",
135 |     )
136 |     parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137 |     parser.add_argument(
138 |         "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139 |     )
140 |     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141 |
142 |     args = parser.parse_args()
143 |
144 |     base_model_path = args.base_model_path
145 |     checkpoint_path = args.checkpoint_path
146 |     dump_path = args.dump_path
147 |     lora_prefix_unet = args.lora_prefix_unet
148 |     lora_prefix_text_encoder = args.lora_prefix_text_encoder
149 |     alpha = args.alpha
150 |
151 |     # load the base pipeline and the LoRA state dict, then merge the weights in place
152 |     pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
153 |     state_dict = load_file(checkpoint_path)
154 |     pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
155 |
156 |     pipe = pipe.to(args.device)
157 |     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
158 |
--------------------------------------------------------------------------------
/musev/utils/noise_util.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | import torch
3 |
4 |
5 | from diffusers.utils.torch_utils import randn_tensor
6 |
7 |
8 | def random_noise(
9 |     tensor: torch.Tensor = None,
10 |     shape: Tuple[int] = None,
11 |     dtype: torch.dtype = None,
12 |     device: torch.device = None,
13 |     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
14 |     noise_offset: Optional[float] = None,  # typical value is 0.1
15 | ) -> torch.Tensor:
16 |     if tensor is not None:
17 |         shape = tensor.shape
18 |         device = tensor.device
19 |         dtype = tensor.dtype
20 |     if isinstance(device, str):
21 |         device = torch.device(device)
22 |     noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator)
23 |     if noise_offset is not None:
24 |         # https://www.crosslabs.org//blog/diffusion-with-offset-noise
25 |         noise += noise_offset * torch.randn(
26 |             (shape[0], shape[1], 1, 1, 1), dtype=dtype, device=device
27 |         )
28 |     return noise
29 |
30 |
31 | def video_fusion_noise(
32 |     tensor: torch.Tensor = None,
33 |     shape: Tuple[int] = None,
34 |     dtype: torch.dtype = None,
35 |     device: torch.device = None,
36 |     w_ind_noise: float = 0.5,
37 |     generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
38 |     initial_common_noise: torch.Tensor = None,
39 | ) -> torch.Tensor:
40 |     if tensor is not None:
41 |         shape = tensor.shape
42 |         device = tensor.device
43 |         dtype = tensor.dtype
44 |     if isinstance(device, str):
45 |         device = torch.device(device)
46 |     batch_size, c, t, h, w = shape
47 |     if isinstance(generator, list) and len(generator) != batch_size:
48 |         raise ValueError(
49 |             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
50 |             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
51 | ) 52 | if not isinstance(generator, list): 53 | if initial_common_noise is not None: 54 | common_noise = initial_common_noise.to(device, dtype=dtype) 55 | else: 56 | common_noise = randn_tensor( 57 | (shape[0], shape[1], 1, shape[3], shape[4]), 58 | generator=generator, 59 | device=device, 60 | dtype=dtype, 61 | ) # common noise 62 | ind_noise = randn_tensor( 63 | shape, 64 | generator=generator, 65 | device=device, 66 | dtype=dtype, 67 | ) # individual noise 68 | s = torch.tensor(w_ind_noise, device=device, dtype=dtype) 69 | latents = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise 70 | else: 71 | latents = [] 72 | for i in range(batch_size): 73 | latent = video_fusion_noise( 74 | shape=(1, c, t, h, w), 75 | dtype=dtype, 76 | device=device, 77 | w_ind_noise=w_ind_noise, 78 | generator=generator[i], 79 | initial_common_noise=initial_common_noise, 80 | ) 81 | latents.append(latent) 82 | latents = torch.cat(latents, dim=0).to(device) 83 | return latents 84 | -------------------------------------------------------------------------------- /musev/utils/register.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | class Register: 7 | def __init__(self, registry_name): 8 | self._dict = {} 9 | self._name = registry_name 10 | 11 | def __setitem__(self, key, value): 12 | if not callable(value): 13 | raise Exception(f"Value of a Registry must be a callable!\nValue: {value}") 14 | # 优先使用自定义的name,其次使用类名或者函数名。 15 | if "name" in value.__dict__: 16 | key = value.name 17 | elif key is None: 18 | key = value.__name__ 19 | if key in self._dict: 20 | logger.warning("Key %s already in registry %s." % (key, self._name)) 21 | self._dict[key] = value 22 | 23 | def register(self, target): 24 | """Decorator to register a function or class.""" 25 | 26 | def add(key, value): 27 | self[key] = value 28 | return value 29 | 30 | if callable(target): 31 | # @reg.register 32 | return add(None, target) 33 | # @reg.register('alias') 34 | return lambda x: add(target, x) 35 | 36 | def __getitem__(self, key): 37 | return self._dict[key] 38 | 39 | def __contains__(self, key): 40 | return key in self._dict 41 | 42 | def keys(self): 43 | """key""" 44 | return self._dict.keys() 45 | -------------------------------------------------------------------------------- /musev/utils/tensor_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def generate_meshgrid_2d(h: int, w: int, device) -> torch.tensor: 6 | x = torch.linspace(-1, 1, h, device=device) 7 | y = torch.linspace(-1, 1, w, device=device) 8 | grid_x, grid_y = torch.meshgrid(x, y) 9 | grid = torch.stack([grid_x, grid_y], dim=2) 10 | return grid 11 | 12 | 13 | def his_match(src, dst): 14 | src = src * 255.0 15 | dst = dst * 255.0 16 | src = src.astype(np.uint8) 17 | dst = dst.astype(np.uint8) 18 | res = np.zeros_like(dst) 19 | 20 | cdf_src = np.zeros((3, 256)) 21 | cdf_dst = np.zeros((3, 256)) 22 | cdf_res = np.zeros((3, 256)) 23 | kw = dict(bins=256, range=(0, 256), density=True) 24 | for ch in range(3): 25 | his_src, _ = np.histogram(src[:, :, ch], **kw) 26 | hist_dst, _ = np.histogram(dst[:, :, ch], **kw) 27 | cdf_src[ch] = np.cumsum(his_src) 28 | cdf_dst[ch] = np.cumsum(hist_dst) 29 | index = np.searchsorted(cdf_src[ch], cdf_dst[ch], side="left") 30 | np.clip(index, 0, 255, out=index) 31 | res[:, :, ch] = index[dst[:, :, ch]] 32 | his_res, _ = np.histogram(res[:, 
:, ch], **kw)
33 |         cdf_res[ch] = np.cumsum(his_res)
34 |     return res / 255.0
35 |
--------------------------------------------------------------------------------
/musev/utils/timesteps_util.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal
2 | import numpy as np
3 |
4 |
5 | def generate_parameters_with_timesteps(
6 |     start: int,
7 |     num: int,
8 |     stop: int = None,
9 |     method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear",
10 |     n_fix_start: int = 3,
11 | ) -> List[float]:
12 |     if stop is None or start == stop:
13 |         params = [start] * num
14 |     else:
15 |         if method == "linear":
16 |             params = generate_linear_parameters(start, stop, num)
17 |         elif method == "two_stage":
18 |             params = generate_two_stages_parameters(start, stop, num)
19 |         elif method == "three_stage":
20 |             params = generate_three_stages_parameters(start, stop, num)
21 |         elif method == "fix_two_stage":
22 |             params = generate_fix_two_stages_parameters(start, stop, num, n_fix_start)
23 |         else:
24 |             raise ValueError(
25 |                 f"only linear, two_stage, three_stage and fix_two_stage are supported, but {method} was given"
26 |             )
27 |     return params
28 |
29 |
30 | def generate_linear_parameters(start, stop, num):
31 |     params = list(
32 |         np.linspace(
33 |             start=start,
34 |             stop=stop,
35 |             num=num,
36 |         )
37 |     )
38 |     return params
39 |
40 |
41 | def generate_two_stages_parameters(start, stop, num):
42 |     num_start = num // 2
43 |     num_end = num - num_start
44 |     params = [start] * num_start + [stop] * num_end
45 |     return params
46 |
47 |
48 | def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List:
49 |     num_start = n_fix_start
50 |     num_end = num - num_start
51 |     params = [start] * num_start + [stop] * num_end
52 |     return params
53 |
54 |
55 | def generate_three_stages_parameters(start, stop, num):
56 |     middle = (start + stop) // 2
57 |     num_start = num // 3
58 |     num_middle = num_start
59 |     num_end = num - num_start - num_middle
60 |     params = [start] * num_start + [middle] * num_middle + [stop] * num_end
61 |     return params
62 |
--------------------------------------------------------------------------------
/musev/utils/vae_util.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 |
3 | from torch import nn
4 | import torch
5 |
6 |
7 | def decode_unet_latents_with_vae(vae: nn.Module, latents: torch.Tensor):
8 |     n_dim = latents.ndim
9 |     batch_size = latents.shape[0]
10 |     if n_dim == 5:
11 |         latents = rearrange(latents, "b c f h w -> (b f) c h w")
12 |     latents = 1 / vae.config.scaling_factor * latents
13 |     video = vae.decode(latents, return_dict=False)[0]
14 |     video = (video / 2 + 0.5).clamp(0, 1)
15 |     if n_dim == 5:
16 |         # restore the frame axis of the decoded video
17 |         video = rearrange(video, "(b f) c h w -> b c f h w", b=batch_size)
18 |     return video
19 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers @ git+https://github.com/TMElyralab/diffusers.git@tme
2 | ip_adapter @ git+https://github.com/tencent-ailab/IP-Adapter.git@main
3 | clip @ git+https://github.com/openai/CLIP.git@main
4 | controlnet_aux @ git+https://github.com/TMElyralab/controlnet_aux.git@tme
5 | mmcm @ git+https://github.com/TMElyralab/MMCM.git@setup
6 | # tensorflow==2.12.0
7 | # tensorboard==2.12.0
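# To use the pinned CUDA 11.8 wheels instead of the default PyPI builds, install them first,
# e.g. `pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118`
# (the commented lines below record those versions), then install the rest of this file.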
8 | # torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 9 | # torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118 10 | # torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 11 | torch 12 | torchvision 13 | torchaudio 14 | ninja==1.11.1 15 | transformers==4.33.1 16 | bitsandbytes==0.41.1 17 | decord==0.6.0 18 | accelerate==0.22.0 19 | xformers==0.0.21 20 | omegaconf 21 | einops 22 | imageio==2.31.1 23 | pandas 24 | h5py 25 | matplotlib 26 | modelcards==0.1.6 27 | pynvml==11.5.0 28 | black 29 | pytest 30 | moviepy==1.0.3 31 | torch-tb-profiler==0.4.1 32 | scikit-learn 33 | librosa 34 | ffmpeg 35 | easydict 36 | webp 37 | mediapipe==0.10.3 38 | cython==3.0.2 39 | easydict 40 | gdown 41 | insightface==0.7.3 42 | ipython 43 | librosa==0.10.1 44 | onnx==1.14.1 45 | onnxruntime==1.15.1 46 | onnxsim==0.4.33 47 | opencv_python 48 | Pillow 49 | protobuf==3.20.3 50 | pytube==15.0.0 51 | PyYAML 52 | requests 53 | scipy 54 | six 55 | tqdm 56 | albumentations==1.3.1 57 | opencv-contrib-python==4.8.0.76 58 | imageio-ffmpeg==0.4.8 59 | pytorch-lightning==2.0.8 60 | test-tube==0.7.5 61 | timm==0.9.12 62 | addict 63 | yapf 64 | prettytable 65 | safetensors==0.3.3 66 | fvcore 67 | pycocotools 68 | wandb==0.15.10 69 | wget 70 | ffmpeg-python 71 | streamlit 72 | webdataset 73 | kornia==0.7.0 74 | open_clip_torch==2.20.0 75 | streamlit-drawable-canvas==0.9.3 76 | torchmetrics==1.1.1 77 | invisible-watermark==0.1.5 78 | gdown==4.5.3 79 | ftfy==6.1.1 80 | jupyters 81 | ipywidgets==8.0.3 82 | ipython 83 | matplotlib==3.6.2 84 | redis==4.5.1 85 | # pydantic~=1.0 86 | gradio 87 | loguru==0.6.0 88 | IProgress==0.4 89 | markupsafe==2.0.1 90 | xlsxwriter 91 | cuid 92 | spaces 93 | # https://mirrors.cloud.tencent.com/pypi/packages/de/a6/a49d5af79a515f5c9552a26b2078d839c40fcf8dccc0d94a1269276ab181/tb_nightly-2.1.0a20191022-py3-none-any.whl 94 | basicsr 95 | # numpy==1.23.5 -------------------------------------------------------------------------------- /scripts/gradio/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM anchorxia/musev:latest 2 | 3 | #MAINTAINER 维护者信息 4 | LABEL MAINTAINER="anchorxia, zhanchao" 5 | LABEL Email="anchorxia@tencent.com, zhanchao019@foxmail.com" 6 | LABEL Description="musev gradio image, from docker pull anchorxia/musev:latest" 7 | 8 | SHELL ["/bin/bash", "--login", "-c"] 9 | 10 | # Set up a new user named "user" with user ID 1000 11 | RUN useradd -m -u 1000 user 12 | 13 | # Switch to the "user" user 14 | USER user 15 | 16 | # Set home to the user's home directory 17 | ENV HOME=/home/user \ 18 | PATH=/home/user/.local/bin:$PATH 19 | 20 | # Set the working directory to the user's home directory 21 | WORKDIR $HOME/app 22 | 23 | RUN echo "docker start"\ 24 | && whoami \ 25 | && which python \ 26 | && pwd 27 | 28 | RUN git clone -b hg_space --recursive https://github.com/TMElyralab/MuseV.git 29 | # RUN mkdir ./MuseV/checkpoints \ 30 | # && ls -l ./MuseV 31 | RUN chmod -R 777 /home/user/app/MuseV 32 | 33 | # RUN git clone -b main https://huggingface.co/TMElyralab/MuseV /home/user/app/MuseV/checkpoints 34 | 35 | RUN . 
/opt/conda/etc/profile.d/conda.sh \ 36 | && echo "source activate musev" >> ~/.bashrc \ 37 | && conda activate musev \ 38 | && conda env list 39 | 40 | RUN echo "export PYTHONPATH=\${PYTHONPATH}:/home/user/app/MuseV:/home/user/app/MuseV/MMCM:/home/user/app/MuseV/diffusers/src:/home/user/app/MuseV/controlnet_aux/src" >> ~/.bashrc 41 | 42 | WORKDIR /home/user/app/MuseV/scripts/gradio/ 43 | 44 | # Add entrypoint script 45 | COPY --chown=user entrypoint.sh ./entrypoint.sh 46 | RUN chmod +x ./entrypoint.sh 47 | RUN ls -l ./ 48 | 49 | EXPOSE 7860 50 | 51 | # CMD ["/bin/bash", "-c", "python app.py"] 52 | CMD ["./entrypoint.sh"] -------------------------------------------------------------------------------- /scripts/gradio/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "entrypoint.sh" 4 | whoami 5 | which python 6 | export PYTHONPATH=${PYTHONPATH}:/home/user/app/MuseV:/home/user/app/MuseV/MMCM:/home/user/app/MuseV/diffusers/src:/home/user/app/MuseV/controlnet_aux/src 7 | echo "pythonpath" $PYTHONPATH 8 | # chmod 777 -R /home/user/app/MuseV 9 | # Print the contents of the diffusers/src directory 10 | # echo "Contents of /home/user/app/MuseV/diffusers/src:" 11 | # Load ~/.bashrc 12 | # source ~/.bashrc 13 | 14 | source /opt/conda/etc/profile.d/conda.sh 15 | conda activate musev 16 | which python 17 | python ap_space.py -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import subprocess 3 | import os 4 | import pkg_resources 5 | 6 | from setuptools import setup, find_packages 7 | 8 | ProjectDir = os.path.dirname(__file__) 9 | result = subprocess.run(["pip", "install", "basicsr"], capture_output=True, text=True) 10 | result = subprocess.run( 11 | ["pip", "install", "--no-cache-dir", "-U", "openmim"], 12 | capture_output=True, 13 | text=True, 14 | ) 15 | result = subprocess.run(["mim", "install", "mmengine"], capture_output=True, text=True) 16 | result = subprocess.run( 17 | ["mim", "install", "mmcv>=2.0.1"], capture_output=True, text=True 18 | ) 19 | result = subprocess.run( 20 | ["mim", "install", "mmdet>=3.1.0"], capture_output=True, text=True 21 | ) 22 | result = subprocess.run( 23 | ["mim", "install", "mmpose>=1.1.0"], capture_output=True, text=True 24 | ) 25 | 26 | with open(os.path.join(ProjectDir, "requirements.txt"), "r") as f: 27 | requirements = f.read().splitlines() 28 | requirements = [x for x in requirements if x and not x.startswith("#")] 29 | requirements = [x.split(" ")[0] if "index-url" in x else x for x in requirements] 30 | 31 | setup( 32 | name="musev", # used in pip install 33 | version="1.0.0", 34 | author="anchorxia, zkangchen", 35 | author_email="anchorxia@tencent.com, zkangchen@tencent.com", 36 | description="Package about human video creation", 37 | # long_description=long_description, 38 | # long_description_content_type="text/markdown", 39 | url="https://github.com/TMElyralab/MuseV", 40 | packages=find_packages("musev"), 41 | package_dir={"": "musev"}, 42 | # include_package_data=True, # please edit MANIFEST.in 43 | classifiers=[ 44 | "Programming Language :: Python :: 3", 45 | "License :: OSI Approved :: MIT License", 46 | "Operating System :: OS Independent", 47 | ], 48 | install_requires=requirements, 49 | # dependency_links=["https://download.pytorch.org/whl/cu118], 50 | ) 51 | 
--------------------------------------------------------------------------------
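A small usage sketch (not part of the repository sources above): assuming the repository root is on `PYTHONPATH` (as the gradio Dockerfile configures) and the requirements are installed, the `video_fusion_noise` helper from `musev/utils/noise_util.py` can be exercised on its own to see how `w_ind_noise` trades off shared versus per-frame noise:

import torch

from musev.utils.noise_util import video_fusion_noise

# latent video tensor layout used throughout MuseV: (batch, channels, frames, height, width)
latents = torch.zeros(1, 4, 8, 32, 32)
generator = torch.Generator().manual_seed(42)

# w_ind_noise=0.0 gives every frame the same noise; 1.0 makes the frames fully independent.
noise = video_fusion_noise(latents, w_ind_noise=0.5, generator=generator)
print(noise.shape)  # torch.Size([1, 4, 8, 32, 32])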