├── .github ├── FUNDING.yml └── workflows │ └── publish_to_registry.yml ├── IF_MemoAvatar.py ├── IF_MemoCheckpointLoader.py ├── LICENSE ├── README.md ├── __init__.py ├── configs └── inference.yaml ├── examples ├── candy.wav ├── candy@2x.png ├── dicaprio.jpg └── speech.wav ├── memo ├── __init__.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── attention_processor.py │ ├── audio_proj.py │ ├── emotion_classifier.py │ ├── image_proj.py │ ├── motion_module.py │ ├── normalization.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ ├── unet_3d_blocks.py │ └── wav2vec.py ├── pipelines │ ├── __init__.py │ └── video_pipeline.py └── utils │ ├── __init__.py │ ├── audio_utils.py │ └── vision_utils.py ├── memo_model_manager.py ├── pyproject.toml ├── requirements.txt ├── web └── js │ └── IF_MemoAvatar.js └── workflow ├── IF_MemoAvatar.json ├── IF_MemoAvatar_IF_Extensions.json └── IF_MemoAvatar_simple.json /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: ImpactFrames 2 | patreon: ImpactFrames 3 | ko_fi: impactframes 4 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_registry.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | jobs: 11 | publish-node: 12 | name: Publish Custom Node to registry 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | - name: Publish Custom Node 18 | uses: Comfy-Org/publish-node-action@main 19 | with: 20 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
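## The value below reads a repository secret named REGISTRY_ACCESS_TOKEN (in the GitHub UI: Settings → Secrets and variables → Actions).
## The secret name itself is arbitrary, but it must match the name referenced in the expression on the next line.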
21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 22 | -------------------------------------------------------------------------------- /IF_MemoAvatar.py: -------------------------------------------------------------------------------- 1 | #IF_MemoAvatar.py 2 | import os 3 | import torch 4 | import numpy as np 5 | import torchaudio 6 | from PIL import Image 7 | import logging 8 | from tqdm import tqdm 9 | import time 10 | from contextlib import contextmanager 11 | 12 | import folder_paths 13 | import comfy.model_management 14 | from diffusers import FlowMatchEulerDiscreteScheduler 15 | from diffusers.utils import is_xformers_available 16 | 17 | from memo.pipelines.video_pipeline import VideoPipeline 18 | from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio 19 | from memo.utils.vision_utils import preprocess_image, tensor_to_video 20 | from memo_model_manager import MemoModelManager 21 | 22 | logger = logging.getLogger("memo") 23 | 24 | class IF_MemoAvatar: 25 | @classmethod 26 | def INPUT_TYPES(s): 27 | return { 28 | "required": { 29 | "image": ("IMAGE",), 30 | "audio": ("AUDIO",), 31 | "reference_net": ("MODEL",), 32 | "diffusion_net": ("MODEL",), 33 | "vae": ("VAE",), 34 | "image_proj": ("IMAGE_PROJ",), 35 | "audio_proj": ("AUDIO_PROJ",), 36 | "emotion_classifier": ("EMOTION_CLASSIFIER",), 37 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 8}), 38 | "num_frames_per_clip": ("INT", {"default": 16, "min": 1, "max": 32}), 39 | "fps": ("INT", {"default": 30, "min": 1, "max": 60}), 40 | "inference_steps": ("INT", {"default": 20, "min": 1, "max": 100}), 41 | "cfg_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 100.0}), 42 | "seed": ("INT", {"default": 42}), 43 | "output_name": ("STRING", {"default": "memo_video"}) 44 | } 45 | } 46 | 47 | RETURN_TYPES = ("STRING", "STRING") 48 | RETURN_NAMES = ("video_path", "status") 49 | FUNCTION = "generate" 50 | CATEGORY = "ImpactFrames💥🎞️/MemoAvatar" 51 | 52 | def __init__(self): 53 | self.device = comfy.model_management.get_torch_device() 54 | # Use bfloat16 if available, fallback to float16 55 | self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 56 | 57 | # Initialize model manager and get paths 58 | self.model_manager = MemoModelManager() 59 | self.paths = self.model_manager.get_model_paths() 60 | 61 | 62 | def generate(self, image, audio, reference_net, diffusion_net, vae, image_proj, audio_proj, 63 | emotion_classifier, resolution=512, num_frames_per_clip=16, fps=30, 64 | inference_steps=20, cfg_scale=3.5, seed=42, output_name="memo_video"): 65 | try: 66 | # Save video 67 | timestamp = time.strftime('%Y%m%d-%H%M%S') 68 | video_name = f"{output_name}_{timestamp}.mp4" 69 | output_dir = folder_paths.get_output_directory() 70 | video_path = os.path.join(output_dir, video_name) 71 | 72 | # Memory optimizations 73 | torch.cuda.empty_cache() 74 | if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): 75 | autocast = torch.cuda.amp.autocast 76 | else: 77 | @contextmanager 78 | def autocast(): 79 | yield 80 | 81 | num_init_past_frames = 2 82 | num_past_frames = 16 83 | 84 | # Save input image temporarily 85 | temp_dir = folder_paths.get_temp_directory() 86 | temp_image = os.path.join(temp_dir, f"ref_image_{time.time()}.png") 87 | 88 | try: 89 | # Convert ComfyUI image format to PIL 90 | if isinstance(image, torch.Tensor): 91 | image = image.cpu().numpy() 92 | if image.ndim == 4: 93 | image = 
image[0] 94 | image = Image.fromarray((image * 255).astype(np.uint8)) 95 | image.save(temp_image) 96 | face_models_path = self.paths["face_models"] 97 | print(f"face_models path: {face_models_path}") 98 | # Process image with our models 99 | pixel_values, face_emb = preprocess_image( 100 | self.paths["face_models"], # face_analysis_model 101 | temp_image, # image_path 102 | resolution # image_size 103 | ) 104 | finally: 105 | if os.path.exists(temp_image): 106 | os.remove(temp_image) 107 | 108 | # Save audio temporarily 109 | temp_dir = folder_paths.get_temp_directory() 110 | temp_audio = os.path.join(temp_dir, f"temp_audio_{time.time()}.wav") 111 | 112 | try: 113 | # Convert 3D tensor to 2D if necessary 114 | waveform = audio["waveform"] 115 | if waveform.ndim == 3: 116 | waveform = waveform.squeeze(0) # Remove batch dimension 117 | 118 | # Save the audio tensor to a temporary WAV file 119 | torchaudio.save(temp_audio, waveform, audio["sample_rate"]) 120 | 121 | # Set up audio cache directory 122 | cache_dir = os.path.join(folder_paths.get_temp_directory(), "memo_audio_cache") 123 | os.makedirs(cache_dir, exist_ok=True) 124 | 125 | resampled_path = os.path.join(cache_dir, f"resampled_{time.time()}-16k.wav") 126 | resampled_path = resample_audio(temp_audio, resampled_path) 127 | 128 | # Process audio 129 | audio_emb, audio_length = preprocess_audio( 130 | wav_path=resampled_path, 131 | num_generated_frames_per_clip=num_frames_per_clip, 132 | fps=fps, 133 | wav2vec_model=self.paths["wav2vec"], 134 | vocal_separator_model=self.paths["vocal_separator"], 135 | cache_dir=cache_dir, 136 | device=str(self.device) 137 | ) 138 | 139 | # Extract emotion 140 | audio_emotion, num_emotion_classes = extract_audio_emotion_labels( 141 | model=self.paths["memo_base"], 142 | wav_path=resampled_path, 143 | emotion2vec_model=self.paths["emotion2vec"], 144 | audio_length=audio_length, 145 | device=str(self.device) 146 | ) 147 | 148 | # Model optimizations 149 | vae.requires_grad_(False).eval() 150 | reference_net.requires_grad_(False).eval() 151 | diffusion_net.requires_grad_(False).eval() 152 | image_proj.requires_grad_(False).eval() 153 | audio_proj.requires_grad_(False).eval() 154 | 155 | # Enable memory efficient attention (Optional) 156 | if is_xformers_available(): 157 | try: 158 | reference_net.enable_xformers_memory_efficient_attention() 159 | diffusion_net.enable_xformers_memory_efficient_attention() 160 | except Exception as e: 161 | logger.warning( 162 | f"Could not enable memory efficient attention for xformers: {e}." 163 | "Do you have xformers installed? 
" 164 | "If you do, please check your xformers installation" 165 | ) 166 | 167 | # Create pipeline with optimizations 168 | noise_scheduler = FlowMatchEulerDiscreteScheduler() 169 | with torch.inference_mode(): 170 | pipeline = VideoPipeline( 171 | vae=vae, 172 | reference_net=reference_net, 173 | diffusion_net=diffusion_net, 174 | scheduler=noise_scheduler, 175 | image_proj=image_proj, 176 | ) 177 | pipeline.to(device=self.device, dtype=self.dtype) 178 | 179 | # Generate video frames with memory optimizations 180 | video_frames = [] 181 | num_clips = audio_emb.shape[0] // num_frames_per_clip 182 | generator = torch.Generator(device=self.device).manual_seed(seed) 183 | 184 | for t in tqdm(range(num_clips), desc="Generating video clips"): 185 | # Clear cache at the start of each iteration 186 | if torch.cuda.is_available(): 187 | torch.cuda.empty_cache() 188 | 189 | if len(video_frames) == 0: 190 | past_frames = pixel_values.repeat(num_init_past_frames, 1, 1, 1) 191 | past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device) 192 | pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0) 193 | else: 194 | past_frames = video_frames[-1][0] 195 | past_frames = past_frames.permute(1, 0, 2, 3) 196 | past_frames = past_frames[0 - num_past_frames:] 197 | past_frames = past_frames * 2.0 - 1.0 198 | past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device) 199 | pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0) 200 | 201 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 202 | 203 | # Process audio in smaller chunks if needed 204 | audio_tensor = ( 205 | audio_emb[ 206 | t * num_frames_per_clip : min( 207 | (t + 1) * num_frames_per_clip, audio_emb.shape[0] 208 | ) 209 | ] 210 | .unsqueeze(0) 211 | .to(device=audio_proj.device, dtype=audio_proj.dtype) 212 | ) 213 | 214 | with torch.inference_mode(): 215 | audio_tensor = audio_proj(audio_tensor) 216 | 217 | audio_emotion_tensor = audio_emotion[ 218 | t * num_frames_per_clip : min( 219 | (t + 1) * num_frames_per_clip, audio_emb.shape[0] 220 | ) 221 | ] 222 | 223 | pipeline_output = pipeline( 224 | ref_image=pixel_values_ref_img, 225 | audio_tensor=audio_tensor, 226 | audio_emotion=audio_emotion_tensor, 227 | emotion_class_num=num_emotion_classes, 228 | face_emb=face_emb, 229 | width=resolution, 230 | height=resolution, 231 | video_length=num_frames_per_clip, 232 | num_inference_steps=inference_steps, 233 | guidance_scale=cfg_scale, 234 | generator=generator, 235 | ) 236 | 237 | video_frames.append(pipeline_output.videos) 238 | 239 | video_frames = torch.cat(video_frames, dim=2) 240 | video_frames = video_frames.squeeze(0) 241 | video_frames = video_frames[:, :audio_length] 242 | 243 | 244 | tensor_to_video(video_frames, video_path, temp_audio, fps=fps) 245 | return (video_path, f"✅ Video saved as {video_name}") 246 | 247 | finally: 248 | # Clean up temporary files 249 | if os.path.exists(temp_audio): 250 | os.remove(temp_audio) 251 | 252 | except Exception as e: 253 | import traceback 254 | traceback.print_exc() 255 | return ("", f"❌ Error: {str(e)}") 256 | 257 | # Node mappings 258 | NODE_CLASS_MAPPINGS = { 259 | "IF_MemoAvatar": IF_MemoAvatar 260 | } 261 | 262 | NODE_DISPLAY_NAME_MAPPINGS = { 263 | "IF_MemoAvatar": "IF MemoAvatar 🗣️" 264 | } -------------------------------------------------------------------------------- /IF_MemoCheckpointLoader.py: -------------------------------------------------------------------------------- 1 | #IF_MemoCheckpointLoader.py 2 
| import os 3 | import torch 4 | import folder_paths 5 | import logging 6 | from diffusers import AutoencoderKL 7 | from diffusers.utils import is_xformers_available 8 | from packaging import version 9 | from safetensors.torch import load_file 10 | from huggingface_hub import hf_hub_download 11 | 12 | from memo.models.unet_2d_condition import UNet2DConditionModel 13 | from memo.models.unet_3d import UNet3DConditionModel 14 | from memo.models.image_proj import ImageProjModel 15 | from memo.models.audio_proj import AudioProjModel 16 | from memo.models.emotion_classifier import AudioEmotionClassifierModel 17 | from memo_model_manager import MemoModelManager 18 | 19 | logger = logging.getLogger("memo") 20 | 21 | class IF_MemoCheckpointLoader: 22 | @classmethod 23 | def INPUT_TYPES(s): 24 | return { 25 | "required": { 26 | "enable_xformers": ("BOOLEAN", {"default": True}), 27 | } 28 | } 29 | 30 | RETURN_TYPES = ("MODEL", "MODEL", "VAE", "IMAGE_PROJ", "AUDIO_PROJ", "EMOTION_CLASSIFIER") 31 | RETURN_NAMES = ("reference_net", "diffusion_net", "vae", "image_proj", "audio_proj", "emotion_classifier") 32 | FUNCTION = "load_checkpoint" 33 | CATEGORY = "ImpactFrames💥🎞️/MemoAvatar" 34 | 35 | def __init__(self): 36 | # Initialize model manager to set up all paths and auxiliary models 37 | self.model_manager = MemoModelManager() 38 | self.paths = self.model_manager.get_model_paths() 39 | 40 | def load_checkpoint(self, enable_xformers=True): 41 | try: 42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 43 | dtype = torch.float16 if str(device) == "cuda" else torch.float32 44 | 45 | logger.info("Loading models") 46 | 47 | # Fallback download function 48 | def fallback_download(repo_id, filename, local_dir): 49 | try: 50 | return hf_hub_download( 51 | repo_id=repo_id, 52 | filename=filename, 53 | local_dir=local_dir, 54 | local_dir_use_symlinks=False 55 | ) 56 | except Exception as e: 57 | logger.warning(f"Failed to download {filename} from {repo_id}: {e}") 58 | return None 59 | 60 | # Load VAE with multiple fallback strategies 61 | try: 62 | vae = AutoencoderKL.from_pretrained( 63 | self.paths["vae"], 64 | use_safetensors=True, 65 | torch_dtype=dtype 66 | ).to(device=device) 67 | except Exception as e: 68 | logger.warning(f"Local VAE load failed: {e}. 
Attempting download.") 69 | fallback_download( 70 | "stabilityai/sd-vae-ft-mse", 71 | "diffusion_pytorch_model.safetensors", 72 | self.paths["vae"] 73 | ) 74 | vae = AutoencoderKL.from_pretrained( 75 | "stabilityai/sd-vae-ft-mse", 76 | use_safetensors=True, 77 | torch_dtype=dtype 78 | ).to(device=device) 79 | 80 | # Load reference net 81 | reference_net = UNet2DConditionModel.from_pretrained( 82 | self.paths["memo_base"], 83 | subfolder="reference_net", 84 | use_safetensors=True 85 | ) 86 | reference_net.requires_grad_(False) 87 | reference_net.eval() 88 | 89 | # Load diffusion net 90 | diffusion_net = UNet3DConditionModel.from_pretrained( 91 | self.paths["memo_base"], 92 | subfolder="diffusion_net", 93 | use_safetensors=True 94 | ) 95 | diffusion_net.requires_grad_(False) 96 | diffusion_net.eval() 97 | 98 | # Load projectors 99 | image_proj = ImageProjModel.from_pretrained( 100 | self.paths["memo_base"], 101 | subfolder="image_proj", 102 | use_safetensors=True 103 | ) 104 | image_proj.requires_grad_(False) 105 | image_proj.eval() 106 | 107 | audio_proj = AudioProjModel.from_pretrained( 108 | self.paths["memo_base"], 109 | subfolder="audio_proj", 110 | use_safetensors=True 111 | ) 112 | audio_proj.requires_grad_(False) 113 | audio_proj.eval() 114 | 115 | # Enable xformers (Optional) 116 | if enable_xformers and is_xformers_available(): 117 | try: 118 | import xformers 119 | xformers_version = version.parse(xformers.__version__) 120 | if xformers_version == version.parse("0.0.16"): 121 | logger.warning("xFormers 0.0.16 cannot be used for training in some GPUs.") 122 | reference_net.enable_xformers_memory_efficient_attention() 123 | diffusion_net.enable_xformers_memory_efficient_attention() 124 | except Exception as e: 125 | logger.warning(f"Could not enable xformers: {e}. Proceeding without xformers.") 126 | else: 127 | logger.info("Xformers is not enabled or not available. 
Proceeding without xformers.") 128 | 129 | # Move models to device 130 | for model in [reference_net, diffusion_net, image_proj, audio_proj]: 131 | model.to(device=device, dtype=dtype) 132 | 133 | # Load emotion classifier 134 | emotion_classifier = AudioEmotionClassifierModel() 135 | emotion_classifier_path = os.path.join( 136 | self.paths["memo_base"], 137 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors" 138 | ) 139 | emotion_classifier.load_state_dict(load_file(emotion_classifier_path)) 140 | emotion_classifier.to(device=device, dtype=dtype) 141 | emotion_classifier.eval() 142 | 143 | logger.info(f"Models loaded successfully to {device} with dtype {dtype}") 144 | return (reference_net, diffusion_net, vae, image_proj, audio_proj, emotion_classifier) 145 | 146 | except Exception as e: 147 | logger.error(f"Comprehensive model loading error: {e}") 148 | import traceback 149 | traceback.print_exc() 150 | raise RuntimeError(f"Failed to load models: {str(e)}") 151 | 152 | @classmethod 153 | def IS_CHANGED(s, **kwargs): 154 | return float("nan") 155 | 156 | NODE_CLASS_MAPPINGS = { 157 | "IF_MemoCheckpointLoader": IF_MemoCheckpointLoader 158 | } 159 | 160 | NODE_DISPLAY_NAME_MAPPINGS = { 161 | "IF_MemoCheckpointLoader": "IF Memo Checkpoint Loader 🎬" 162 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 ImpactFrames 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-IF_MemoAvatar 2 | Memory-Guided Diffusion for Expressive Talking Video Generation 3 | 4 | 5 | ![demo](https://github.com/user-attachments/assets/f11caaf4-2345-4e69-b6b4-ffa895116f29) 6 | 7 | #ORIGINAL REPO 8 | **MEMO: Memory-Guided Diffusion for Expressive Talking Video Generation** 9 |
10 | [Longtao Zheng](https://ltzheng.github.io)\*, 11 | [Yifan Zhang](https://scholar.google.com/citations?user=zuYIUJEAAAAJ)\*, 12 | [Hanzhong Guo](https://scholar.google.com/citations?user=q3x6KsgAAAAJ)\, 13 | [Jiachun Pan](https://scholar.google.com/citations?user=nrOvfb4AAAAJ), 14 | [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ), 15 | [Jiahao Lu](https://scholar.google.com/citations?user=h7rbA-sAAAAJ), 16 | [Chuanxin Tang](https://scholar.google.com/citations?user=3ZC8B7MAAAAJ), 17 | [Bo An](https://personal.ntu.edu.sg/boan/index.html), 18 | [Shuicheng Yan](https://scholar.google.com/citations?user=DNuiPHwAAAAJ) 19 |
20 | _[Project Page](https://memoavatar.github.io) | [arXiv](https://arxiv.org/abs/2412.04448) | [Model](https://huggingface.co/memoavatar/memo)_ 21 | 22 | This repository contains the example inference script for the MEMO-preview model. The gif demo below is compressed. See our [project page](https://memoavatar.github.io) for full videos. 23 | 24 |
25 | Demo GIF 26 |
27 | 28 | # ComfyUI-IF_MemoAvatar 29 | Memory-Guided Diffusion for Expressive Talking Video Generation 30 | 31 | ## Overview 32 | This is a ComfyUI implementation of MEMO (Memory-Guided Diffusion for Expressive Talking Video Generation), which enables the creation of expressive talking avatar videos from a single image and audio input. 33 | 34 | ## Features 35 | - Generate expressive talking head videos from a single image 36 | - Audio-driven facial animation 37 | - Emotional expression transfer 38 | - High-quality video output 39 | ![thorium_XMBCG9kbGn](https://github.com/user-attachments/assets/7d19178f-870c-429b-988e-1335fe1ba8f9) 40 | 41 | 42 | https://github.com/user-attachments/assets/bfbf896d-a609-4e0f-8ed3-16ec48f8d85a 43 | 44 | 45 | ## Installation 46 | 47 | *** xformers is NOT REQUIRED, but better if installed *** 48 | *** MAKE SURE YOU HAVE an HF token set in your environment variables *** 49 | 50 | Git clone the repo into your `custom_nodes` folder, then: 51 | ```bash 52 | cd ComfyUI-IF_MemoAvatar 53 | pip install -r requirements.txt 54 | ``` 55 | I removed xformers from the requirements file because it needs a particular combination of PyTorch versions to work on Windows. 56 | 57 | If you are on Linux you can simply run 58 | ```bash 59 | pip install xformers 60 | ``` 61 | For Windows users, check whether xformers is already in your environment: 62 | ```bash 63 | pip show xformers 64 | ``` 65 | If no version is listed, install the latest one and follow this free guide to set up a good ComfyUI environment: 66 | 67 | [Installing Triton and Sage Attention Flash Attention](https://ko-fi.com/post/Installing-Triton-and-Sage-Attention-Flash-Attenti-P5P8175434) 68 | 69 | 70 | [![Watch the video](https://img.youtube.com/vi/nSUGEdm2wU4/hqdefault.jpg)](https://www.youtube.com/watch?v=nSUGEdm2wU4) 71 | 72 | 73 | ### Model Files 74 | The models will automatically download to the following locations in your ComfyUI installation: 75 | 76 | ```bash 77 | models/checkpoints/memo/ 78 | ├── audio_proj/ 79 | ├── diffusion_net/ 80 | ├── image_proj/ 81 | ├── misc/ 82 | │ ├── audio_emotion_classifier/ 83 | │ ├── face_analysis/ 84 | │ └── vocal_separator/ 85 | └── reference_net/ 86 | models/wav2vec/ 87 | models/vae/sd-vae-ft-mse/ 88 | models/emotion2vec/emotion2vec_plus_large/ 89 | 90 | ``` 91 | 92 | Copy the models from the `face_analysis/models` folder directly into `face_analysis` (a temporary workaround until the loader is fixed); 93 | duplicate them rather than just moving them, because 94 | Hugging Face will detect the empty folder and re-download the models every time. 95 | If you don't see a `models.json`, or it errors out, create one yourself with this content: 96 | ```json 97 | { 98 | "detection": [ 99 | "scrfd_10g_bnkps" 100 | ], 101 | "recognition": [ 102 | "glintr100" 103 | ], 104 | "analysis": [ 105 | "genderage", 106 | "2d106det", 107 | "1k3d68" 108 | ] 109 | } 110 | ``` 111 | and a `version.txt` containing 112 | `0.7.3` 113 | 114 | ![yW8hDQhnhM](https://github.com/user-attachments/assets/1c11e940-2da3-4d43-9453-cc1be06942c3) 115 | 116 | :IF_MemoAvatar_comfy 117 | 118 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #__init__.py 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | # Get the absolute path to the current directory and memo directory 7 | CURRENT_DIR = Path(__file__).parent.absolute() 8 | MEMO_DIR = CURRENT_DIR / "memo" 9 | 10 | # Add both directories to Python path if they're not already there 11 | if str(CURRENT_DIR) not in sys.path: 12 |
sys.path.insert(0, str(CURRENT_DIR)) 13 | if str(MEMO_DIR) not in sys.path: 14 | sys.path.insert(0, str(MEMO_DIR)) 15 | 16 | # Create an empty __init__.py in memo directory if it doesn't exist 17 | memo_init = MEMO_DIR / "__init__.py" 18 | if not memo_init.exists(): 19 | memo_init.touch() 20 | 21 | # Now import the components using absolute imports 22 | from .memo_model_manager import MemoModelManager 23 | from .IF_MemoAvatar import IF_MemoAvatar 24 | from .IF_MemoCheckpointLoader import IF_MemoCheckpointLoader 25 | 26 | NODE_CLASS_MAPPINGS = { 27 | "IF_MemoAvatar": IF_MemoAvatar, 28 | "IF_MemoCheckpointLoader": IF_MemoCheckpointLoader, 29 | } 30 | 31 | NODE_DISPLAY_NAME_MAPPINGS = { 32 | "IF_MemoAvatar": "IF MemoAvatar 🗣️", 33 | "IF_MemoCheckpointLoader": "IF Memo Checkpoint Loader" 34 | } 35 | 36 | # Define web directory relative to this file 37 | WEB_DIRECTORY = os.path.join(os.path.dirname(__file__), "web") 38 | 39 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | resolution: 512 2 | num_generated_frames_per_clip: 16 3 | fps: 30 4 | num_init_past_frames: 2 5 | num_past_frames: 16 6 | inference_steps: 20 7 | cfg_scale: 3.5 8 | weight_dtype: bf16 9 | enable_xformers_memory_efficient_attention: true 10 | 11 | model_name_or_path: memoavatar/memo 12 | # model_name_or_path: checkpoints 13 | vae: stabilityai/sd-vae-ft-mse 14 | wav2vec: facebook/wav2vec2-base-960h 15 | emotion2vec: emotion2vec/emotion2vec_plus_large 16 | misc_model_dir: checkpoints 17 | 18 | -------------------------------------------------------------------------------- /examples/candy.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/candy.wav -------------------------------------------------------------------------------- /examples/candy@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/candy@2x.png -------------------------------------------------------------------------------- /examples/dicaprio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/dicaprio.jpg -------------------------------------------------------------------------------- /examples/speech.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/speech.wav -------------------------------------------------------------------------------- /memo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/__init__.py -------------------------------------------------------------------------------- /memo/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/models/__init__.py 
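The nodes registered in `__init__.py` above are normally chained inside a ComfyUI graph (see the JSON files under `workflow/`), with `IF_MemoCheckpointLoader` feeding `IF_MemoAvatar`. For readers who want the call sequence end to end, here is a minimal, untested sketch in plain Python; it assumes a working ComfyUI environment (so `folder_paths` and `comfy.model_management` are importable and `MemoModelManager` can fetch the weights), uses the sample assets from `examples/`, and the file paths and variable names are illustrative only. The keyword values mirror the defaults in `configs/inference.yaml` and the node's `INPUT_TYPES`.

```python
# Illustrative sketch only: run from a ComfyUI environment with this repo on sys.path.
import numpy as np
import torch
import torchaudio
from PIL import Image

from IF_MemoCheckpointLoader import IF_MemoCheckpointLoader
from IF_MemoAvatar import IF_MemoAvatar

# ComfyUI IMAGE tensors are float32 in [0, 1] with shape (batch, H, W, 3).
pil = Image.open("examples/dicaprio.jpg").convert("RGB")
image = torch.from_numpy(np.array(pil).astype(np.float32) / 255.0).unsqueeze(0)

# ComfyUI AUDIO inputs are dicts holding a (batch, channels, samples) waveform and a sample rate.
waveform, sample_rate = torchaudio.load("examples/speech.wav")
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}

# Load every MEMO sub-model, then run the generation node.
loader = IF_MemoCheckpointLoader()
reference_net, diffusion_net, vae, image_proj, audio_proj, emotion_classifier = loader.load_checkpoint(
    enable_xformers=True
)

node = IF_MemoAvatar()
video_path, status = node.generate(
    image, audio, reference_net, diffusion_net, vae, image_proj, audio_proj, emotion_classifier,
    resolution=512, num_frames_per_clip=16, fps=30,
    inference_steps=20, cfg_scale=3.5, seed=42, output_name="memo_video",
)
print(status, video_path)
```

Inside ComfyUI itself none of this is needed: the loader and generator nodes expose the same inputs through the graph UI.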
-------------------------------------------------------------------------------- /memo/models/audio_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import ConfigMixin, ModelMixin 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class AudioProjModel(ModelMixin, ConfigMixin): 8 | def __init__( 9 | self, 10 | seq_len=5, 11 | blocks=12, # add a new parameter blocks 12 | channels=768, # add a new parameter channels 13 | intermediate_dim=512, 14 | output_dim=768, 15 | context_tokens=32, 16 | ): 17 | super().__init__() 18 | 19 | self.seq_len = seq_len 20 | self.blocks = blocks 21 | self.channels = channels 22 | self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels. 23 | self.intermediate_dim = intermediate_dim 24 | self.context_tokens = context_tokens 25 | self.output_dim = output_dim 26 | 27 | # define multiple linear layers 28 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 29 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 30 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 31 | 32 | self.norm = nn.LayerNorm(output_dim) 33 | 34 | def forward(self, audio_embeds): 35 | video_length = audio_embeds.shape[1] 36 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 37 | batch_size, window_size, blocks, channels = audio_embeds.shape 38 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 39 | 40 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 41 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 42 | 43 | context_tokens = self.proj3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim) 44 | 45 | context_tokens = self.norm(context_tokens) 46 | context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) 47 | 48 | return context_tokens 49 | -------------------------------------------------------------------------------- /memo/models/emotion_classifier.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from diffusers import ConfigMixin, ModelMixin 5 | 6 | 7 | class AudioEmotionClassifierModel(ModelMixin, ConfigMixin): 8 | num_emotion_classes = 9 9 | 10 | def __init__(self, num_classifier_layers=5, num_classifier_channels=2048): 11 | super().__init__() 12 | 13 | if num_classifier_layers == 1: 14 | self.layers = torch.nn.Linear(1024, self.num_emotion_classes) 15 | else: 16 | layer_list = [ 17 | ("fc1", torch.nn.Linear(1024, num_classifier_channels)), 18 | ("relu1", torch.nn.ReLU()), 19 | ] 20 | for n in range(num_classifier_layers - 2): 21 | layer_list.append((f"fc{n+2}", torch.nn.Linear(num_classifier_channels, num_classifier_channels))) 22 | layer_list.append((f"relu{n+2}", torch.nn.ReLU())) 23 | layer_list.append( 24 | (f"fc{num_classifier_layers}", torch.nn.Linear(num_classifier_channels, self.num_emotion_classes)) 25 | ) 26 | self.layers = torch.nn.Sequential(OrderedDict(layer_list)) 27 | 28 | def forward(self, x): 29 | x = self.layers(x) 30 | x = torch.softmax(x, dim=-1) 31 | return x 32 | -------------------------------------------------------------------------------- /memo/models/image_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import ConfigMixin, ModelMixin 3 | 4 | 5 | class ImageProjModel(ModelMixin, ConfigMixin): 6 | def 
__init__( 7 | self, 8 | cross_attention_dim=768, 9 | clip_embeddings_dim=512, 10 | clip_extra_context_tokens=4, 11 | ): 12 | super().__init__() 13 | 14 | self.generator = None 15 | self.cross_attention_dim = cross_attention_dim 16 | self.clip_extra_context_tokens = clip_extra_context_tokens 17 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 18 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 19 | 20 | def forward(self, image_embeds): 21 | embeds = image_embeds 22 | clip_extra_context_tokens = self.proj(embeds).reshape( 23 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 24 | ) 25 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 26 | return clip_extra_context_tokens 27 | -------------------------------------------------------------------------------- /memo/models/motion_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import xformers 5 | import xformers.ops 6 | from diffusers.models.attention import FeedForward 7 | from diffusers.models.attention_processor import Attention 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from memo.models.attention import zero_module 13 | from memo.models.attention_processor import ( 14 | MemoryLinearAttnProcessor, 15 | ) 16 | 17 | 18 | class PositionalEncoding(nn.Module): 19 | def __init__(self, d_model, dropout=0.0, max_len=24): 20 | super().__init__() 21 | self.dropout = nn.Dropout(p=dropout) 22 | position = torch.arange(max_len).unsqueeze(1) 23 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 24 | pe = torch.zeros(1, max_len, d_model) 25 | pe[0, :, 0::2] = torch.sin(position * div_term) 26 | pe[0, :, 1::2] = torch.cos(position * div_term) 27 | self.register_buffer("pe", pe) 28 | 29 | def forward(self, x, offset=0): 30 | x = x + self.pe[:, offset : offset + x.size(1)] 31 | return self.dropout(x) 32 | 33 | 34 | class MemoryLinearAttnTemporalModule(nn.Module): 35 | def __init__( 36 | self, 37 | in_channels, 38 | num_attention_heads=8, 39 | num_transformer_block=2, 40 | attention_block_types=("Temporal_Self", "Temporal_Self"), 41 | temporal_position_encoding=False, 42 | temporal_position_encoding_max_len=24, 43 | temporal_attention_dim_div=1, 44 | zero_initialize=True, 45 | ): 46 | super().__init__() 47 | 48 | self.temporal_transformer = TemporalLinearAttnTransformer( 49 | in_channels=in_channels, 50 | num_attention_heads=num_attention_heads, 51 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 52 | num_layers=num_transformer_block, 53 | attention_block_types=attention_block_types, 54 | temporal_position_encoding=temporal_position_encoding, 55 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 56 | ) 57 | 58 | if zero_initialize: 59 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 60 | 61 | def forward( 62 | self, 63 | hidden_states, 64 | motion_frames, 65 | encoder_hidden_states, 66 | is_new_audio=True, 67 | update_past_memory=False, 68 | ): 69 | hidden_states = self.temporal_transformer( 70 | hidden_states, 71 | motion_frames, 72 | encoder_hidden_states, 73 | is_new_audio=is_new_audio, 74 | update_past_memory=update_past_memory, 75 | ) 76 | 77 | output = hidden_states 78 | return output 79 | 80 | 81 | class TemporalLinearAttnTransformer(nn.Module): 82 | def 
__init__( 83 | self, 84 | in_channels, 85 | num_attention_heads, 86 | attention_head_dim, 87 | num_layers, 88 | attention_block_types=( 89 | "Temporal_Self", 90 | "Temporal_Self", 91 | ), 92 | dropout=0.0, 93 | norm_num_groups=32, 94 | cross_attention_dim=768, 95 | activation_fn="geglu", 96 | attention_bias=False, 97 | upcast_attention=False, 98 | temporal_position_encoding=False, 99 | temporal_position_encoding_max_len=24, 100 | ): 101 | super().__init__() 102 | 103 | inner_dim = num_attention_heads * attention_head_dim 104 | 105 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 106 | self.proj_in = nn.Linear(in_channels, inner_dim) 107 | 108 | self.transformer_blocks = nn.ModuleList( 109 | [ 110 | TemporalLinearAttnTransformerBlock( 111 | dim=inner_dim, 112 | num_attention_heads=num_attention_heads, 113 | attention_head_dim=attention_head_dim, 114 | attention_block_types=attention_block_types, 115 | dropout=dropout, 116 | cross_attention_dim=cross_attention_dim, 117 | activation_fn=activation_fn, 118 | attention_bias=attention_bias, 119 | upcast_attention=upcast_attention, 120 | temporal_position_encoding=temporal_position_encoding, 121 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 122 | ) 123 | for _ in range(num_layers) 124 | ] 125 | ) 126 | self.proj_out = nn.Linear(inner_dim, in_channels) 127 | 128 | def forward( 129 | self, 130 | hidden_states, 131 | motion_frames, 132 | encoder_hidden_states=None, 133 | is_new_audio=True, 134 | update_past_memory=False, 135 | ): 136 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 137 | video_length = hidden_states.shape[2] 138 | n_motion_frames = motion_frames.shape[2] 139 | 140 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 141 | with torch.no_grad(): 142 | motion_frames = rearrange(motion_frames, "b c f h w -> (b f) c h w") 143 | 144 | batch, _, height, weight = hidden_states.shape 145 | residual = hidden_states 146 | 147 | hidden_states = self.norm(hidden_states) 148 | with torch.no_grad(): 149 | motion_frames = self.norm(motion_frames) 150 | 151 | inner_dim = hidden_states.shape[1] 152 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 153 | hidden_states = self.proj_in(hidden_states) 154 | 155 | with torch.no_grad(): 156 | ( 157 | motion_frames_batch, 158 | motion_frames_inner_dim, 159 | motion_frames_height, 160 | motion_frames_weight, 161 | ) = motion_frames.shape 162 | 163 | motion_frames = motion_frames.permute(0, 2, 3, 1).reshape( 164 | motion_frames_batch, 165 | motion_frames_height * motion_frames_weight, 166 | motion_frames_inner_dim, 167 | ) 168 | motion_frames = self.proj_in(motion_frames) 169 | 170 | # Transformer Blocks 171 | for block in self.transformer_blocks: 172 | hidden_states = block( 173 | hidden_states, 174 | motion_frames, 175 | encoder_hidden_states=encoder_hidden_states, 176 | video_length=video_length, 177 | n_motion_frames=n_motion_frames, 178 | is_new_audio=is_new_audio, 179 | update_past_memory=update_past_memory, 180 | ) 181 | 182 | # output 183 | hidden_states = self.proj_out(hidden_states) 184 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 185 | 186 | output = hidden_states + residual 187 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 188 | 189 | return output 190 | 191 | 192 | class 
TemporalLinearAttnTransformerBlock(nn.Module): 193 | def __init__( 194 | self, 195 | dim, 196 | num_attention_heads, 197 | attention_head_dim, 198 | attention_block_types=( 199 | "Temporal_Self", 200 | "Temporal_Self", 201 | ), 202 | dropout=0.0, 203 | cross_attention_dim=768, 204 | activation_fn="geglu", 205 | attention_bias=False, 206 | upcast_attention=False, 207 | temporal_position_encoding=False, 208 | temporal_position_encoding_max_len=24, 209 | ): 210 | super().__init__() 211 | 212 | attention_blocks = [] 213 | norms = [] 214 | 215 | for block_name in attention_block_types: 216 | attention_blocks.append( 217 | MemoryLinearAttention( 218 | attention_mode=block_name.split("_", maxsplit=1)[0], 219 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 220 | query_dim=dim, 221 | heads=num_attention_heads, 222 | dim_head=attention_head_dim, 223 | dropout=dropout, 224 | bias=attention_bias, 225 | upcast_attention=upcast_attention, 226 | temporal_position_encoding=temporal_position_encoding, 227 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 228 | ) 229 | ) 230 | norms.append(nn.LayerNorm(dim)) 231 | 232 | self.attention_blocks = nn.ModuleList(attention_blocks) 233 | self.norms = nn.ModuleList(norms) 234 | 235 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 236 | self.ff_norm = nn.LayerNorm(dim) 237 | 238 | def forward( 239 | self, 240 | hidden_states, 241 | motion_frames, 242 | encoder_hidden_states=None, 243 | video_length=None, 244 | n_motion_frames=None, 245 | is_new_audio=True, 246 | update_past_memory=False, 247 | ): 248 | for attention_block, norm in zip(self.attention_blocks, self.norms): 249 | norm_hidden_states = norm(hidden_states) 250 | with torch.no_grad(): 251 | norm_motion_frames = norm(motion_frames) 252 | hidden_states = ( 253 | attention_block( 254 | norm_hidden_states, 255 | norm_motion_frames, 256 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 257 | video_length=video_length, 258 | n_motion_frames=n_motion_frames, 259 | is_new_audio=is_new_audio, 260 | update_past_memory=update_past_memory, 261 | ) 262 | + hidden_states 263 | ) 264 | 265 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 266 | 267 | output = hidden_states 268 | return output 269 | 270 | 271 | class MemoryLinearAttention(Attention): 272 | def __init__( 273 | self, 274 | *args, 275 | attention_mode=None, 276 | temporal_position_encoding=False, 277 | temporal_position_encoding_max_len=24, 278 | **kwargs, 279 | ): 280 | super().__init__(*args, **kwargs) 281 | assert attention_mode == "Temporal" 282 | 283 | self.attention_mode = attention_mode 284 | self.is_cross_attention = kwargs.get("cross_attention_dim") is not None 285 | self.query_dim = kwargs["query_dim"] 286 | self.temporal_position_encoding_max_len = temporal_position_encoding_max_len 287 | self.pos_encoder = ( 288 | PositionalEncoding( 289 | kwargs["query_dim"], 290 | dropout=0.0, 291 | max_len=temporal_position_encoding_max_len, 292 | ) 293 | if (temporal_position_encoding and attention_mode == "Temporal") 294 | else None 295 | ) 296 | 297 | def extra_repr(self): 298 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 299 | 300 | def set_use_memory_efficient_attention_xformers( 301 | self, 302 | use_memory_efficient_attention_xformers: bool, 303 | attention_op=None, 304 | ): 305 | if use_memory_efficient_attention_xformers: 306 | if not 
is_xformers_available(): 307 | raise ModuleNotFoundError( 308 | ( 309 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 310 | " xformers" 311 | ), 312 | name="xformers", 313 | ) 314 | 315 | if not torch.cuda.is_available(): 316 | raise ValueError( 317 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 318 | " only available for GPU " 319 | ) 320 | 321 | try: 322 | # Make sure we can run the memory efficient attention 323 | _ = xformers.ops.memory_efficient_attention( 324 | torch.randn((1, 2, 40), device="cuda"), 325 | torch.randn((1, 2, 40), device="cuda"), 326 | torch.randn((1, 2, 40), device="cuda"), 327 | ) 328 | except Exception as e: 329 | raise e 330 | processor = MemoryLinearAttnProcessor() 331 | else: 332 | processor = MemoryLinearAttnProcessor() 333 | 334 | self.set_processor(processor) 335 | 336 | def forward( 337 | self, 338 | hidden_states, 339 | motion_frames, 340 | encoder_hidden_states=None, 341 | attention_mask=None, 342 | video_length=None, 343 | n_motion_frames=None, 344 | is_new_audio=True, 345 | update_past_memory=False, 346 | **cross_attention_kwargs, 347 | ): 348 | if self.attention_mode == "Temporal": 349 | d = hidden_states.shape[1] 350 | hidden_states = rearrange( 351 | hidden_states, 352 | "(b f) d c -> (b d) f c", 353 | f=video_length, 354 | ) 355 | 356 | if self.pos_encoder is not None: 357 | hidden_states = self.pos_encoder(hidden_states) 358 | 359 | with torch.no_grad(): 360 | motion_frames = rearrange(motion_frames, "(b f) d c -> (b d) f c", f=n_motion_frames) 361 | 362 | encoder_hidden_states = ( 363 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) 364 | if encoder_hidden_states is not None 365 | else encoder_hidden_states 366 | ) 367 | 368 | else: 369 | raise NotImplementedError 370 | 371 | hidden_states = self.processor( 372 | self, 373 | hidden_states, 374 | motion_frames, 375 | encoder_hidden_states=encoder_hidden_states, 376 | attention_mask=attention_mask, 377 | n_motion_frames=n_motion_frames, 378 | is_new_audio=is_new_audio, 379 | update_past_memory=update_past_memory, 380 | **cross_attention_kwargs, 381 | ) 382 | 383 | if self.attention_mode == "Temporal": 384 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 385 | 386 | return hidden_states 387 | -------------------------------------------------------------------------------- /memo/models/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class FP32LayerNorm(nn.LayerNorm): 7 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 8 | origin_dtype = inputs.dtype 9 | return F.layer_norm( 10 | inputs.float(), 11 | self.normalized_shape, 12 | self.weight.float() if self.weight is not None else None, 13 | self.bias.float() if self.bias is not None else None, 14 | self.eps, 15 | ).to(origin_dtype) 16 | -------------------------------------------------------------------------------- /memo/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class InflatedConv3d(nn.Conv2d): 8 | def forward(self, x): 9 | video_length = x.shape[2] 10 | 11 | x = rearrange(x, "b c f h w -> (b f) c h w") 12 | x = super().forward(x) 13 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 
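# The two rearrange calls above fold the frame axis into the batch axis so the pretrained 2D
# convolution can process every frame of a (b, c, f, h, w) video tensor, then restore the frame axis.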
14 | 15 | return x 16 | 17 | 18 | class InflatedGroupNorm(nn.GroupNorm): 19 | def forward(self, x): 20 | video_length = x.shape[2] 21 | 22 | x = rearrange(x, "b c f h w -> (b f) c h w") 23 | x = super().forward(x) 24 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 25 | 26 | return x 27 | 28 | 29 | class Upsample3D(nn.Module): 30 | def __init__( 31 | self, 32 | channels, 33 | use_conv=False, 34 | use_conv_transpose=False, 35 | out_channels=None, 36 | name="conv", 37 | ): 38 | super().__init__() 39 | self.channels = channels 40 | self.out_channels = out_channels or channels 41 | self.use_conv = use_conv 42 | self.use_conv_transpose = use_conv_transpose 43 | self.name = name 44 | 45 | if use_conv_transpose: 46 | raise NotImplementedError 47 | if use_conv: 48 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 49 | 50 | def forward(self, hidden_states, output_size=None): 51 | assert hidden_states.shape[1] == self.channels 52 | 53 | if self.use_conv_transpose: 54 | raise NotImplementedError 55 | 56 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 57 | dtype = hidden_states.dtype 58 | if dtype == torch.bfloat16: 59 | hidden_states = hidden_states.to(torch.float32) 60 | 61 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 62 | if hidden_states.shape[0] >= 64: 63 | hidden_states = hidden_states.contiguous() 64 | 65 | # if `output_size` is passed we force the interpolation output 66 | # size and do not make use of `scale_factor=2` 67 | if output_size is None: 68 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 69 | else: 70 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 71 | 72 | # If the input is bfloat16, we cast back to bfloat16 73 | if dtype == torch.bfloat16: 74 | hidden_states = hidden_states.to(dtype) 75 | 76 | hidden_states = self.conv(hidden_states) 77 | 78 | return hidden_states 79 | 80 | 81 | class Downsample3D(nn.Module): 82 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 83 | super().__init__() 84 | self.channels = channels 85 | self.out_channels = out_channels or channels 86 | self.use_conv = use_conv 87 | self.padding = padding 88 | stride = 2 89 | self.name = name 90 | 91 | if use_conv: 92 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 93 | else: 94 | raise NotImplementedError 95 | 96 | def forward(self, hidden_states): 97 | assert hidden_states.shape[1] == self.channels 98 | if self.use_conv and self.padding == 0: 99 | raise NotImplementedError 100 | 101 | assert hidden_states.shape[1] == self.channels 102 | hidden_states = self.conv(hidden_states) 103 | 104 | return hidden_states 105 | 106 | 107 | class ResnetBlock3D(nn.Module): 108 | def __init__( 109 | self, 110 | *, 111 | in_channels, 112 | out_channels=None, 113 | conv_shortcut=False, 114 | dropout=0.0, 115 | temb_channels=512, 116 | groups=32, 117 | groups_out=None, 118 | pre_norm=True, 119 | eps=1e-6, 120 | non_linearity="swish", 121 | time_embedding_norm="default", 122 | output_scale_factor=1.0, 123 | use_in_shortcut=None, 124 | use_inflated_groupnorm=None, 125 | ): 126 | super().__init__() 127 | self.pre_norm = pre_norm 128 | self.pre_norm = True 129 | self.in_channels = in_channels 130 | out_channels = in_channels if out_channels is None else out_channels 131 | self.out_channels = out_channels 132 | 
self.use_conv_shortcut = conv_shortcut 133 | self.time_embedding_norm = time_embedding_norm 134 | self.output_scale_factor = output_scale_factor 135 | 136 | if groups_out is None: 137 | groups_out = groups 138 | 139 | assert use_inflated_groupnorm is not None 140 | if use_inflated_groupnorm: 141 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 142 | else: 143 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | 145 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 146 | 147 | if temb_channels is not None: 148 | if self.time_embedding_norm == "default": 149 | time_emb_proj_out_channels = out_channels 150 | elif self.time_embedding_norm == "scale_shift": 151 | time_emb_proj_out_channels = out_channels * 2 152 | else: 153 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 154 | 155 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 156 | else: 157 | self.time_emb_proj = None 158 | 159 | if use_inflated_groupnorm: 160 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 161 | else: 162 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | self.dropout = torch.nn.Dropout(dropout) 164 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 165 | 166 | if non_linearity == "swish": 167 | self.nonlinearity = F.silu() 168 | elif non_linearity == "mish": 169 | self.nonlinearity = Mish() 170 | elif non_linearity == "silu": 171 | self.nonlinearity = nn.SiLU() 172 | 173 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 174 | 175 | self.conv_shortcut = None 176 | if self.use_in_shortcut: 177 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 178 | 179 | def forward(self, input_tensor, temb): 180 | hidden_states = input_tensor 181 | 182 | hidden_states = self.norm1(hidden_states) 183 | hidden_states = self.nonlinearity(hidden_states) 184 | 185 | hidden_states = self.conv1(hidden_states) 186 | 187 | if temb is not None: 188 | if temb.dim() == 3: 189 | temb = self.time_emb_proj(self.nonlinearity(temb)) 190 | temb = temb.transpose(1, 2).unsqueeze(-1).unsqueeze(-1) 191 | else: 192 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 193 | 194 | if temb is not None and self.time_embedding_norm == "default": 195 | hidden_states = hidden_states + temb 196 | 197 | hidden_states = self.norm2(hidden_states) 198 | 199 | if temb is not None and self.time_embedding_norm == "scale_shift": 200 | scale, shift = torch.chunk(temb, 2, dim=1) 201 | hidden_states = hidden_states * (1 + scale) + shift 202 | 203 | hidden_states = self.nonlinearity(hidden_states) 204 | 205 | hidden_states = self.dropout(hidden_states) 206 | hidden_states = self.conv2(hidden_states) 207 | 208 | if self.conv_shortcut is not None: 209 | input_tensor = self.conv_shortcut(input_tensor) 210 | 211 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 212 | 213 | return output_tensor 214 | 215 | 216 | class Mish(torch.nn.Module): 217 | def forward(self, hidden_states): 218 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 219 | -------------------------------------------------------------------------------- 
/memo/models/transformer_2d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 7 | from diffusers.models.modeling_utils import ModelMixin 8 | from diffusers.models.normalization import AdaLayerNormSingle 9 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 10 | from torch import nn 11 | 12 | from memo.models.attention import BasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer2DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | ref_feature_list: list[torch.FloatTensor] 19 | 20 | 21 | class Transformer2DModel(ModelMixin, ConfigMixin): 22 | _supports_gradient_checkpointing = True 23 | 24 | @register_to_config 25 | def __init__( 26 | self, 27 | num_attention_heads: int = 16, 28 | attention_head_dim: int = 88, 29 | in_channels: Optional[int] = None, 30 | out_channels: Optional[int] = None, 31 | num_layers: int = 1, 32 | dropout: float = 0.0, 33 | norm_num_groups: int = 32, 34 | cross_attention_dim: Optional[int] = None, 35 | attention_bias: bool = False, 36 | num_vector_embeds: Optional[int] = None, 37 | patch_size: Optional[int] = None, 38 | activation_fn: str = "geglu", 39 | num_embeds_ada_norm: Optional[int] = None, 40 | use_linear_projection: bool = False, 41 | only_cross_attention: bool = False, 42 | double_self_attention: bool = False, 43 | upcast_attention: bool = False, 44 | norm_type: str = "layer_norm", 45 | norm_elementwise_affine: bool = True, 46 | norm_eps: float = 1e-5, 47 | attention_type: str = "default", 48 | is_final_block: bool = False, 49 | ): 50 | super().__init__() 51 | self.use_linear_projection = use_linear_projection 52 | self.num_attention_heads = num_attention_heads 53 | self.attention_head_dim = attention_head_dim 54 | self.is_final_block = is_final_block 55 | inner_dim = num_attention_heads * attention_head_dim 56 | 57 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 58 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 59 | 60 | # 1. Transformer2DModel can process both standard continuous images of 61 | # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of 62 | # shape `(batch_size, num_image_vectors)` 63 | # Define whether input is continuous or discrete depending on configuration 64 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 65 | self.is_input_vectorized = num_vector_embeds is not None 66 | self.is_input_patches = in_channels is not None and patch_size is not None 67 | 68 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 69 | deprecation_message = ( 70 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 71 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 72 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 73 | " results in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it" 74 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 75 | ) 76 | deprecate( 77 | "norm_type!=num_embeds_ada_norm", 78 | "1.0.0", 79 | deprecation_message, 80 | standard_warn=False, 81 | ) 82 | norm_type = "ada_norm" 83 | 84 | if self.is_input_continuous and self.is_input_vectorized: 85 | raise ValueError( 86 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 87 | " sure that either `in_channels` or `num_vector_embeds` is None." 88 | ) 89 | 90 | if self.is_input_vectorized and self.is_input_patches: 91 | raise ValueError( 92 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 93 | " sure that either `num_vector_embeds` or `num_patches` is None." 94 | ) 95 | 96 | if not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 97 | raise ValueError( 98 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 99 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 100 | ) 101 | 102 | # 2. Define input layers 103 | self.in_channels = in_channels 104 | 105 | self.norm = torch.nn.GroupNorm( 106 | num_groups=norm_num_groups, 107 | num_channels=in_channels, 108 | eps=1e-6, 109 | affine=True, 110 | ) 111 | if use_linear_projection: 112 | self.proj_in = linear_cls(in_channels, inner_dim) 113 | else: 114 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 115 | 116 | # 3. Define transformers blocks 117 | self.transformer_blocks = nn.ModuleList( 118 | [ 119 | BasicTransformerBlock( 120 | inner_dim, 121 | num_attention_heads, 122 | attention_head_dim, 123 | dropout=dropout, 124 | cross_attention_dim=cross_attention_dim, 125 | activation_fn=activation_fn, 126 | num_embeds_ada_norm=num_embeds_ada_norm, 127 | attention_bias=attention_bias, 128 | only_cross_attention=only_cross_attention, 129 | double_self_attention=double_self_attention, 130 | upcast_attention=upcast_attention, 131 | norm_type=norm_type, 132 | norm_elementwise_affine=norm_elementwise_affine, 133 | norm_eps=norm_eps, 134 | attention_type=attention_type, 135 | is_final_block=(is_final_block and d == num_layers - 1), 136 | ) 137 | for d in range(num_layers) 138 | ] 139 | ) 140 | 141 | # 4. Define output layers 142 | self.out_channels = in_channels if out_channels is None else out_channels 143 | # TODO: should use out_channels for continuous projections 144 | if not is_final_block: 145 | if use_linear_projection: 146 | self.proj_out = linear_cls(inner_dim, in_channels) 147 | else: 148 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 149 | 150 | # 5. PixArt-Alpha blocks. 
151 | self.adaln_single = None 152 | self.use_additional_conditions = False 153 | if norm_type == "ada_norm_single": 154 | self.use_additional_conditions = self.config.sample_size == 128 155 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 156 | # additional conditions until we find better name 157 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 158 | 159 | self.caption_projection = None 160 | 161 | self.gradient_checkpointing = False 162 | 163 | def _set_gradient_checkpointing(self, module, value=False): 164 | if hasattr(module, "gradient_checkpointing"): 165 | module.gradient_checkpointing = value 166 | 167 | def forward( 168 | self, 169 | hidden_states: torch.Tensor, 170 | encoder_hidden_states: Optional[torch.Tensor] = None, 171 | timestep: Optional[torch.LongTensor] = None, 172 | class_labels: Optional[torch.LongTensor] = None, 173 | cross_attention_kwargs: Dict[str, Any] = None, 174 | attention_mask: Optional[torch.Tensor] = None, 175 | encoder_attention_mask: Optional[torch.Tensor] = None, 176 | return_dict: bool = True, 177 | ): 178 | if attention_mask is not None and attention_mask.ndim == 2: 179 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 180 | attention_mask = attention_mask.unsqueeze(1) 181 | 182 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 183 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 184 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 185 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 186 | 187 | # Retrieve lora scale. 188 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 189 | 190 | # 1. Input 191 | batch, _, height, width = hidden_states.shape 192 | residual = hidden_states 193 | 194 | hidden_states = self.norm(hidden_states) 195 | if not self.use_linear_projection: 196 | hidden_states = ( 197 | self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states) 198 | ) 199 | inner_dim = hidden_states.shape[1] 200 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 201 | else: 202 | inner_dim = hidden_states.shape[1] 203 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 204 | hidden_states = ( 205 | self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states) 206 | ) 207 | 208 | # 2. 
Blocks 209 | if self.caption_projection is not None: 210 | batch_size = hidden_states.shape[0] 211 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 212 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 213 | 214 | ref_feature_list = [] 215 | for block in self.transformer_blocks: 216 | if self.training and self.gradient_checkpointing: 217 | 218 | def create_custom_forward(module, return_dict=None): 219 | def custom_forward(*inputs): 220 | if return_dict is not None: 221 | return module(*inputs, return_dict=return_dict) 222 | 223 | return module(*inputs) 224 | 225 | return custom_forward 226 | 227 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 228 | hidden_states, ref_feature = torch.utils.checkpoint.checkpoint( 229 | create_custom_forward(block), 230 | hidden_states, 231 | attention_mask, 232 | encoder_hidden_states, 233 | encoder_attention_mask, 234 | timestep, 235 | cross_attention_kwargs, 236 | class_labels, 237 | **ckpt_kwargs, 238 | ) 239 | else: 240 | hidden_states, ref_feature = block( 241 | hidden_states, # shape [5, 4096, 320] 242 | attention_mask=attention_mask, 243 | encoder_hidden_states=encoder_hidden_states, # shape [1,4,768] 244 | encoder_attention_mask=encoder_attention_mask, 245 | timestep=timestep, 246 | cross_attention_kwargs=cross_attention_kwargs, 247 | class_labels=class_labels, 248 | ) 249 | ref_feature_list.append(ref_feature) 250 | 251 | # 3. Output 252 | output = None 253 | 254 | if self.is_final_block: 255 | if not return_dict: 256 | return (output, ref_feature_list) 257 | 258 | return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list) 259 | 260 | if self.is_input_continuous: 261 | if not self.use_linear_projection: 262 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 263 | hidden_states = ( 264 | self.proj_out(hidden_states, scale=lora_scale) 265 | if not USE_PEFT_BACKEND 266 | else self.proj_out(hidden_states) 267 | ) 268 | else: 269 | hidden_states = ( 270 | self.proj_out(hidden_states, scale=lora_scale) 271 | if not USE_PEFT_BACKEND 272 | else self.proj_out(hidden_states) 273 | ) 274 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 275 | 276 | output = hidden_states + residual 277 | if not return_dict: 278 | return (output, ref_feature_list) 279 | 280 | return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list) 281 | -------------------------------------------------------------------------------- /memo/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from einops import rearrange, repeat 9 | from torch import nn 10 | 11 | from memo.models.attention import JointAudioTemporalBasicTransformerBlock, TemporalBasicTransformerBlock 12 | 13 | 14 | def create_custom_forward(module, return_dict=None): 15 | def custom_forward(*inputs): 16 | if return_dict is not None: 17 | return module(*inputs, return_dict=return_dict) 18 | 19 | return module(*inputs) 20 | 21 | return custom_forward 22 | 23 | 24 | @dataclass 25 | class Transformer3DModelOutput(BaseOutput): 26 | sample: 
torch.FloatTensor 27 | 28 | 29 | class Transformer3DModel(ModelMixin, ConfigMixin): 30 | _supports_gradient_checkpointing = True 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | activation_fn: str = "geglu", 44 | use_linear_projection: bool = False, 45 | only_cross_attention: bool = False, 46 | upcast_attention: bool = False, 47 | unet_use_cross_frame_attention=None, 48 | unet_use_temporal_attention=None, 49 | use_audio_module=False, 50 | depth=0, 51 | unet_block_name=None, 52 | emo_drop_rate=0.3, 53 | is_final_block=False, 54 | ): 55 | super().__init__() 56 | self.use_linear_projection = use_linear_projection 57 | self.num_attention_heads = num_attention_heads 58 | self.attention_head_dim = attention_head_dim 59 | inner_dim = num_attention_heads * attention_head_dim 60 | self.use_audio_module = use_audio_module 61 | # Define input layers 62 | self.in_channels = in_channels 63 | self.is_final_block = is_final_block 64 | 65 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 66 | if use_linear_projection: 67 | self.proj_in = nn.Linear(in_channels, inner_dim) 68 | else: 69 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 70 | 71 | if use_audio_module: 72 | self.transformer_blocks = nn.ModuleList( 73 | [ 74 | JointAudioTemporalBasicTransformerBlock( 75 | dim=inner_dim, 76 | num_attention_heads=num_attention_heads, 77 | attention_head_dim=attention_head_dim, 78 | dropout=dropout, 79 | cross_attention_dim=cross_attention_dim, 80 | activation_fn=activation_fn, 81 | attention_bias=attention_bias, 82 | only_cross_attention=only_cross_attention, 83 | upcast_attention=upcast_attention, 84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 85 | unet_use_temporal_attention=unet_use_temporal_attention, 86 | depth=depth, 87 | unet_block_name=unet_block_name, 88 | use_ada_layer_norm=True, 89 | emo_drop_rate=emo_drop_rate, 90 | is_final_block=(is_final_block and d == num_layers - 1), 91 | ) 92 | for d in range(num_layers) 93 | ] 94 | ) 95 | else: 96 | self.transformer_blocks = nn.ModuleList( 97 | [ 98 | TemporalBasicTransformerBlock( 99 | inner_dim, 100 | num_attention_heads, 101 | attention_head_dim, 102 | dropout=dropout, 103 | cross_attention_dim=cross_attention_dim, 104 | activation_fn=activation_fn, 105 | attention_bias=attention_bias, 106 | only_cross_attention=only_cross_attention, 107 | upcast_attention=upcast_attention, 108 | ) 109 | for _ in range(num_layers) 110 | ] 111 | ) 112 | 113 | # 4. 
Define output layers 114 | if use_linear_projection: 115 | self.proj_out = nn.Linear(in_channels, inner_dim) 116 | else: 117 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 118 | 119 | self.gradient_checkpointing = False 120 | 121 | def _set_gradient_checkpointing(self, module, value=False): 122 | if hasattr(module, "gradient_checkpointing"): 123 | module.gradient_checkpointing = value 124 | 125 | def forward( 126 | self, 127 | hidden_states, 128 | ref_img_feature=None, 129 | encoder_hidden_states=None, 130 | attention_mask=None, 131 | timestep=None, 132 | emotion=None, 133 | uc_mask=None, 134 | return_dict: bool = True, 135 | ): 136 | # Input 137 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 138 | video_length = hidden_states.shape[2] 139 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 140 | 141 | if self.use_audio_module: 142 | if encoder_hidden_states.dim() == 4: 143 | encoder_hidden_states = rearrange( 144 | encoder_hidden_states, 145 | "bs f margin dim -> (bs f) margin dim", 146 | ) 147 | else: 148 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 149 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length) 150 | 151 | batch, _, height, weight = hidden_states.shape 152 | residual = hidden_states 153 | if self.use_audio_module: 154 | residual_audio = encoder_hidden_states 155 | 156 | hidden_states = self.norm(hidden_states) 157 | if not self.use_linear_projection: 158 | hidden_states = self.proj_in(hidden_states) 159 | inner_dim = hidden_states.shape[1] 160 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 161 | else: 162 | inner_dim = hidden_states.shape[1] 163 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 164 | hidden_states = self.proj_in(hidden_states) 165 | 166 | # Blocks 167 | for block in self.transformer_blocks: 168 | if self.training and self.gradient_checkpointing: 169 | if isinstance(block, TemporalBasicTransformerBlock): 170 | hidden_states = torch.utils.checkpoint.checkpoint( 171 | create_custom_forward(block), 172 | hidden_states, 173 | ref_img_feature, 174 | None, # attention_mask 175 | encoder_hidden_states, 176 | timestep, 177 | None, # cross_attention_kwargs 178 | video_length, 179 | uc_mask, 180 | ) 181 | elif isinstance(block, JointAudioTemporalBasicTransformerBlock): 182 | ( 183 | hidden_states, 184 | encoder_hidden_states, 185 | ) = torch.utils.checkpoint.checkpoint( 186 | create_custom_forward(block), 187 | hidden_states, 188 | encoder_hidden_states, 189 | attention_mask, 190 | emotion, 191 | ) 192 | else: 193 | hidden_states = torch.utils.checkpoint.checkpoint( 194 | create_custom_forward(block), 195 | hidden_states, 196 | encoder_hidden_states, 197 | timestep, 198 | attention_mask, 199 | video_length, 200 | ) 201 | else: 202 | if isinstance(block, TemporalBasicTransformerBlock): 203 | hidden_states = block( 204 | hidden_states=hidden_states, 205 | ref_img_feature=ref_img_feature, 206 | encoder_hidden_states=encoder_hidden_states, 207 | timestep=timestep, 208 | video_length=video_length, 209 | uc_mask=uc_mask, 210 | ) 211 | elif isinstance(block, JointAudioTemporalBasicTransformerBlock): 212 | hidden_states, encoder_hidden_states = block( 213 | hidden_states, # shape [2, 4096, 320] 214 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640] 215 | attention_mask=attention_mask, 216 
| emotion=emotion, 217 | ) 218 | else: 219 | hidden_states = block( 220 | hidden_states, # shape [2, 4096, 320] 221 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640] 222 | attention_mask=attention_mask, 223 | timestep=timestep, 224 | video_length=video_length, 225 | ) 226 | 227 | # Output 228 | if not self.use_linear_projection: 229 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 230 | hidden_states = self.proj_out(hidden_states) 231 | else: 232 | hidden_states = self.proj_out(hidden_states) 233 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 234 | 235 | output = hidden_states + residual 236 | 237 | if self.use_audio_module and not self.is_final_block: 238 | audio_output = encoder_hidden_states + residual_audio 239 | else: 240 | audio_output = None 241 | 242 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 243 | if not return_dict: 244 | if self.use_audio_module: 245 | return output, audio_output 246 | else: 247 | return output 248 | 249 | if self.use_audio_module: 250 | return output, audio_output 251 | else: 252 | return output 253 | -------------------------------------------------------------------------------- /memo/models/unet_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.models.attention_processor import AttentionProcessor 9 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 10 | from diffusers.models.modeling_utils import ModelMixin 11 | from diffusers.utils import BaseOutput, logging 12 | 13 | from memo.models.resnet import InflatedConv3d, InflatedGroupNorm 14 | from memo.models.unet_3d_blocks import ( 15 | UNetMidBlock3DCrossAttn, 16 | get_down_block, 17 | get_up_block, 18 | ) 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | @dataclass 25 | class UNet3DConditionOutput(BaseOutput): 26 | sample: torch.FloatTensor 27 | 28 | 29 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 30 | _supports_gradient_checkpointing = True 31 | 32 | @register_to_config 33 | def __init__( 34 | self, 35 | sample_size: Optional[int] = None, 36 | in_channels: int = 8, 37 | out_channels: int = 8, 38 | flip_sin_to_cos: bool = True, 39 | freq_shift: int = 0, 40 | down_block_types: Tuple[str] = ( 41 | "CrossAttnDownBlock3D", 42 | "CrossAttnDownBlock3D", 43 | "CrossAttnDownBlock3D", 44 | "DownBlock3D", 45 | ), 46 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 47 | up_block_types: Tuple[str] = ( 48 | "UpBlock3D", 49 | "CrossAttnUpBlock3D", 50 | "CrossAttnUpBlock3D", 51 | "CrossAttnUpBlock3D", 52 | ), 53 | only_cross_attention: Union[bool, Tuple[bool]] = False, 54 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 55 | layers_per_block: int = 2, 56 | downsample_padding: int = 1, 57 | mid_block_scale_factor: float = 1, 58 | act_fn: str = "silu", 59 | norm_num_groups: int = 32, 60 | norm_eps: float = 1e-5, 61 | cross_attention_dim: int = 1280, 62 | attention_head_dim: Union[int, Tuple[int]] = 8, 63 | dual_cross_attention: bool = False, 64 | use_linear_projection: bool = False, 65 | class_embed_type: Optional[str] = None, 66 | num_class_embeds: Optional[int] = None, 67 | upcast_attention: bool 
= False, 68 | resnet_time_scale_shift: str = "default", 69 | use_inflated_groupnorm=False, 70 | # Additional 71 | motion_module_resolutions=(1, 2, 4, 8), 72 | motion_module_kwargs=None, 73 | unet_use_cross_frame_attention=None, 74 | unet_use_temporal_attention=None, 75 | # audio 76 | audio_attention_dim=768, 77 | emo_drop_rate=0.3, 78 | ): 79 | super().__init__() 80 | 81 | self.sample_size = sample_size 82 | time_embed_dim = block_out_channels[0] * 4 83 | 84 | # input 85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 86 | 87 | # time 88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 89 | timestep_input_dim = block_out_channels[0] 90 | 91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 92 | 93 | # class embedding 94 | if class_embed_type is None and num_class_embeds is not None: 95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 96 | elif class_embed_type == "timestep": 97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 98 | elif class_embed_type == "identity": 99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 100 | else: 101 | self.class_embedding = None 102 | 103 | self.down_blocks = nn.ModuleList([]) 104 | self.mid_block = None 105 | self.up_blocks = nn.ModuleList([]) 106 | 107 | if isinstance(only_cross_attention, bool): 108 | only_cross_attention = [only_cross_attention] * len(down_block_types) 109 | 110 | if isinstance(attention_head_dim, int): 111 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 112 | 113 | # down 114 | output_channel = block_out_channels[0] 115 | for i, down_block_type in enumerate(down_block_types): 116 | res = 2**i 117 | input_channel = output_channel 118 | output_channel = block_out_channels[i] 119 | is_final_block = i == len(block_out_channels) - 1 120 | 121 | down_block = get_down_block( 122 | down_block_type, 123 | num_layers=layers_per_block, 124 | in_channels=input_channel, 125 | out_channels=output_channel, 126 | temb_channels=time_embed_dim, 127 | add_downsample=not is_final_block, 128 | resnet_eps=norm_eps, 129 | resnet_act_fn=act_fn, 130 | resnet_groups=norm_num_groups, 131 | cross_attention_dim=cross_attention_dim, 132 | attn_num_head_channels=attention_head_dim[i], 133 | downsample_padding=downsample_padding, 134 | dual_cross_attention=dual_cross_attention, 135 | use_linear_projection=use_linear_projection, 136 | only_cross_attention=only_cross_attention[i], 137 | upcast_attention=upcast_attention, 138 | resnet_time_scale_shift=resnet_time_scale_shift, 139 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 140 | unet_use_temporal_attention=unet_use_temporal_attention, 141 | use_inflated_groupnorm=use_inflated_groupnorm, 142 | use_motion_module=res in motion_module_resolutions, 143 | motion_module_kwargs=motion_module_kwargs, 144 | audio_attention_dim=audio_attention_dim, 145 | depth=i, 146 | emo_drop_rate=emo_drop_rate, 147 | ) 148 | self.down_blocks.append(down_block) 149 | 150 | # mid 151 | if mid_block_type == "UNetMidBlock3DCrossAttn": 152 | self.mid_block = UNetMidBlock3DCrossAttn( 153 | in_channels=block_out_channels[-1], 154 | temb_channels=time_embed_dim, 155 | resnet_eps=norm_eps, 156 | resnet_act_fn=act_fn, 157 | output_scale_factor=mid_block_scale_factor, 158 | resnet_time_scale_shift=resnet_time_scale_shift, 159 | cross_attention_dim=cross_attention_dim, 160 | attn_num_head_channels=attention_head_dim[-1], 161 | 
resnet_groups=norm_num_groups, 162 | dual_cross_attention=dual_cross_attention, 163 | use_linear_projection=use_linear_projection, 164 | upcast_attention=upcast_attention, 165 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 166 | unet_use_temporal_attention=unet_use_temporal_attention, 167 | use_inflated_groupnorm=use_inflated_groupnorm, 168 | motion_module_kwargs=motion_module_kwargs, 169 | audio_attention_dim=audio_attention_dim, 170 | depth=3, 171 | emo_drop_rate=emo_drop_rate, 172 | ) 173 | else: 174 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 175 | 176 | # count how many layers upsample the videos 177 | self.num_upsamplers = 0 178 | 179 | # up 180 | reversed_block_out_channels = list(reversed(block_out_channels)) 181 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 182 | only_cross_attention = list(reversed(only_cross_attention)) 183 | output_channel = reversed_block_out_channels[0] 184 | for i, up_block_type in enumerate(up_block_types): 185 | res = 2 ** (3 - i) 186 | is_final_block = i == len(block_out_channels) - 1 187 | 188 | prev_output_channel = output_channel 189 | output_channel = reversed_block_out_channels[i] 190 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 191 | 192 | # add upsample block for all BUT final layer 193 | if not is_final_block: 194 | add_upsample = True 195 | self.num_upsamplers += 1 196 | else: 197 | add_upsample = False 198 | 199 | up_block = get_up_block( 200 | up_block_type, 201 | num_layers=layers_per_block + 1, 202 | in_channels=input_channel, 203 | out_channels=output_channel, 204 | prev_output_channel=prev_output_channel, 205 | temb_channels=time_embed_dim, 206 | add_upsample=add_upsample, 207 | resnet_eps=norm_eps, 208 | resnet_act_fn=act_fn, 209 | resnet_groups=norm_num_groups, 210 | cross_attention_dim=cross_attention_dim, 211 | attn_num_head_channels=reversed_attention_head_dim[i], 212 | dual_cross_attention=dual_cross_attention, 213 | use_linear_projection=use_linear_projection, 214 | only_cross_attention=only_cross_attention[i], 215 | upcast_attention=upcast_attention, 216 | resnet_time_scale_shift=resnet_time_scale_shift, 217 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 218 | unet_use_temporal_attention=unet_use_temporal_attention, 219 | use_inflated_groupnorm=use_inflated_groupnorm, 220 | use_motion_module=res in motion_module_resolutions, 221 | motion_module_kwargs=motion_module_kwargs, 222 | audio_attention_dim=audio_attention_dim, 223 | depth=3 - i, 224 | emo_drop_rate=emo_drop_rate, 225 | is_final_block=is_final_block, 226 | ) 227 | self.up_blocks.append(up_block) 228 | prev_output_channel = output_channel 229 | 230 | # out 231 | if use_inflated_groupnorm: 232 | self.conv_norm_out = InflatedGroupNorm( 233 | num_channels=block_out_channels[0], 234 | num_groups=norm_num_groups, 235 | eps=norm_eps, 236 | ) 237 | else: 238 | self.conv_norm_out = nn.GroupNorm( 239 | num_channels=block_out_channels[0], 240 | num_groups=norm_num_groups, 241 | eps=norm_eps, 242 | ) 243 | self.conv_act = nn.SiLU() 244 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 245 | 246 | @property 247 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 248 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 249 | r""" 250 | Returns: 251 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 252 | indexed by its weight 
name. 253 | """ 254 | # set recursively 255 | processors = {} 256 | 257 | def fn_recursive_add_processors( 258 | name: str, 259 | module: torch.nn.Module, 260 | processors: Dict[str, AttentionProcessor], 261 | ): 262 | if hasattr(module, "set_processor"): 263 | processors[f"{name}.processor"] = module.processor 264 | 265 | for sub_name, child in module.named_children(): 266 | if "temporal_transformer" not in sub_name: 267 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 268 | 269 | return processors 270 | 271 | for name, module in self.named_children(): 272 | if "temporal_transformer" not in name: 273 | fn_recursive_add_processors(name, module, processors) 274 | 275 | return processors 276 | 277 | def set_attention_slice(self, slice_size): 278 | r""" 279 | Enable sliced attention computation. 280 | 281 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 282 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 283 | 284 | Args: 285 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 286 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 287 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 288 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 289 | must be a multiple of `slice_size`. 290 | """ 291 | sliceable_head_dims = [] 292 | 293 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 294 | if hasattr(module, "set_attention_slice"): 295 | sliceable_head_dims.append(module.sliceable_head_dim) 296 | 297 | for child in module.children(): 298 | fn_recursive_retrieve_slicable_dims(child) 299 | 300 | # retrieve number of attention layers 301 | for module in self.children(): 302 | fn_recursive_retrieve_slicable_dims(module) 303 | 304 | num_slicable_layers = len(sliceable_head_dims) 305 | 306 | if slice_size == "auto": 307 | # half the attention head size is usually a good trade-off between 308 | # speed and memory 309 | slice_size = [dim // 2 for dim in sliceable_head_dims] 310 | elif slice_size == "max": 311 | # make smallest slice possible 312 | slice_size = num_slicable_layers * [1] 313 | 314 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 315 | 316 | if len(slice_size) != len(sliceable_head_dims): 317 | raise ValueError( 318 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 319 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 320 | ) 321 | 322 | for i, size in enumerate(slice_size): 323 | dim = sliceable_head_dims[i] 324 | if size is not None and size > dim: 325 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 326 | 327 | # Recursively walk through all the children. 
328 | # Any children which exposes the set_attention_slice method 329 | # gets the message 330 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 331 | if hasattr(module, "set_attention_slice"): 332 | module.set_attention_slice(slice_size.pop()) 333 | 334 | for child in module.children(): 335 | fn_recursive_set_attention_slice(child, slice_size) 336 | 337 | reversed_slice_size = list(reversed(slice_size)) 338 | for module in self.children(): 339 | fn_recursive_set_attention_slice(module, reversed_slice_size) 340 | 341 | def _set_gradient_checkpointing(self, module, value=False): 342 | if hasattr(module, "gradient_checkpointing"): 343 | module.gradient_checkpointing = value 344 | 345 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 346 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 347 | r""" 348 | Sets the attention processor to use to compute attention. 349 | 350 | Parameters: 351 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 352 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 353 | for **all** `Attention` layers. 354 | 355 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 356 | processor. This is strongly recommended when setting trainable attention processors. 357 | 358 | """ 359 | count = len(self.attn_processors.keys()) 360 | 361 | if isinstance(processor, dict) and len(processor) != count: 362 | raise ValueError( 363 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 364 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 365 | ) 366 | 367 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 368 | if hasattr(module, "set_processor"): 369 | if not isinstance(processor, dict): 370 | module.set_processor(processor) 371 | else: 372 | module.set_processor(processor.pop(f"{name}.processor")) 373 | 374 | for sub_name, child in module.named_children(): 375 | if "temporal_transformer" not in sub_name: 376 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 377 | 378 | for name, module in self.named_children(): 379 | if "temporal_transformer" not in name: 380 | fn_recursive_attn_processor(name, module, processor) 381 | 382 | def forward( 383 | self, 384 | sample: torch.FloatTensor, 385 | ref_features: dict, 386 | timestep: Union[torch.Tensor, float, int, list], 387 | encoder_hidden_states: torch.Tensor, 388 | audio_embedding: Optional[torch.Tensor] = None, 389 | audio_emotion: Optional[torch.Tensor] = None, 390 | class_labels: Optional[torch.Tensor] = None, 391 | mask_cond_fea: Optional[torch.Tensor] = None, 392 | attention_mask: Optional[torch.Tensor] = None, 393 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 394 | mid_block_additional_residual: Optional[torch.Tensor] = None, 395 | uc_mask: Optional[torch.Tensor] = None, 396 | return_dict: bool = True, 397 | is_new_audio=True, 398 | update_past_memory=False, 399 | ) -> Union[UNet3DConditionOutput, Tuple]: 400 | # By default samples have to be AT least a multiple of the overall upsampling factor. 401 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 402 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 403 | # on the fly if necessary. 
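# NOTE (editorial comment, not in the upstream source): with the default four up blocks the
# final one adds no upsampler, so `self.num_upsamplers == 3` and
# `default_overall_up_factor == 2**3 == 8`. Latent height/width should therefore be
# multiples of 8; otherwise `forward_upsample_size` is set below and the interpolation
# output size is forced explicitly in the up blocks.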
404 | default_overall_up_factor = 2**self.num_upsamplers 405 | 406 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 407 | forward_upsample_size = False 408 | upsample_size = None 409 | 410 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 411 | logger.info("Forward upsample size to force interpolation output size.") 412 | forward_upsample_size = True 413 | 414 | # prepare attention_mask 415 | if attention_mask is not None: 416 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 417 | attention_mask = attention_mask.unsqueeze(1) 418 | 419 | # center input if necessary 420 | if self.config.center_input_sample: 421 | sample = 2 * sample - 1.0 422 | 423 | # time 424 | timesteps = timestep 425 | if isinstance(timesteps, list): 426 | t_emb_list = [] 427 | for timesteps in timestep: 428 | if not torch.is_tensor(timesteps): 429 | # This would be a good case for the `match` statement (Python 3.10+) 430 | is_mps = sample.device.type == "mps" 431 | if isinstance(timestep, float): 432 | dtype = torch.float32 if is_mps else torch.float64 433 | else: 434 | dtype = torch.int32 if is_mps else torch.int64 435 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 436 | elif len(timesteps.shape) == 0: 437 | timesteps = timesteps[None].to(sample.device) 438 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 439 | timesteps = timesteps.expand(sample.shape[0]) 440 | t_emb = self.time_proj(timesteps) 441 | t_emb_list.append(t_emb) 442 | 443 | t_emb = torch.stack(t_emb_list, dim=1) 444 | else: 445 | if not torch.is_tensor(timesteps): 446 | # This would be a good case for the `match` statement (Python 3.10+) 447 | is_mps = sample.device.type == "mps" 448 | if isinstance(timestep, float): 449 | dtype = torch.float32 if is_mps else torch.float64 450 | else: 451 | dtype = torch.int32 if is_mps else torch.int64 452 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 453 | elif len(timesteps.shape) == 0: 454 | timesteps = timesteps[None].to(sample.device) 455 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 456 | timesteps = timesteps.expand(sample.shape[0]) 457 | t_emb = self.time_proj(timesteps) 458 | 459 | # timesteps does not contain any weights and will always return f32 tensors 460 | # but time_embedding might actually be running in fp16. so we need to cast here. 461 | # there might be better ways to encapsulate this. 
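# NOTE (editorial comment, not in the upstream source): when `timestep` is a list, each entry
# was embedded separately above and stacked along dim=1, so `t_emb` has shape
# (batch, num_timesteps, block_out_channels[0]); the cast below only aligns its dtype with
# the model weights before `time_embedding` is applied.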
462 | t_emb = t_emb.to(dtype=self.dtype) 463 | emb = self.time_embedding(t_emb) 464 | 465 | if self.class_embedding is not None: 466 | if class_labels is None: 467 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 468 | 469 | if self.config.class_embed_type == "timestep": 470 | class_labels = self.time_proj(class_labels) 471 | 472 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 473 | emb = emb + class_emb 474 | 475 | # pre-process 476 | sample = self.conv_in(sample) 477 | if mask_cond_fea is not None: 478 | sample = sample + mask_cond_fea 479 | 480 | # down 481 | down_block_res_samples = (sample,) 482 | for i, downsample_block in enumerate(self.down_blocks): 483 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 484 | sample, res_samples, audio_embedding = downsample_block( 485 | hidden_states=sample, 486 | ref_feature_list=ref_features["down"][i], 487 | temb=emb, 488 | encoder_hidden_states=encoder_hidden_states, 489 | attention_mask=attention_mask, 490 | audio_embedding=audio_embedding, 491 | emotion=audio_emotion, 492 | uc_mask=uc_mask, 493 | is_new_audio=is_new_audio, 494 | update_past_memory=update_past_memory, 495 | ) 496 | else: 497 | sample, res_samples = downsample_block( 498 | hidden_states=sample, 499 | ref_feature_list=ref_features["down"][i], 500 | temb=emb, 501 | encoder_hidden_states=encoder_hidden_states, 502 | is_new_audio=is_new_audio, 503 | update_past_memory=update_past_memory, 504 | ) 505 | 506 | down_block_res_samples += res_samples 507 | 508 | if down_block_additional_residuals is not None: 509 | new_down_block_res_samples = () 510 | 511 | for down_block_res_sample, down_block_additional_residual in zip( 512 | down_block_res_samples, down_block_additional_residuals 513 | ): 514 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 515 | new_down_block_res_samples += (down_block_res_sample,) 516 | 517 | down_block_res_samples = new_down_block_res_samples 518 | 519 | # mid 520 | sample, audio_embedding = self.mid_block( 521 | sample, 522 | ref_feature_list=ref_features["mid"][0], 523 | temb=emb, 524 | encoder_hidden_states=encoder_hidden_states, 525 | attention_mask=attention_mask, 526 | audio_embedding=audio_embedding, 527 | emotion=audio_emotion, 528 | uc_mask=uc_mask, 529 | is_new_audio=is_new_audio, 530 | update_past_memory=update_past_memory, 531 | ) 532 | 533 | if mid_block_additional_residual is not None: 534 | sample = sample + mid_block_additional_residual 535 | 536 | # up 537 | for i, upsample_block in enumerate(self.up_blocks): 538 | is_final_block = i == len(self.up_blocks) - 1 539 | 540 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 541 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 542 | 543 | # if we have not reached the final block and need to forward the 544 | # upsample size, we do it here 545 | if not is_final_block and forward_upsample_size: 546 | upsample_size = down_block_res_samples[-1].shape[2:] 547 | 548 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 549 | sample, audio_embedding = upsample_block( 550 | hidden_states=sample, 551 | ref_feature_list=ref_features["up"][i], 552 | temb=emb, 553 | res_hidden_states_tuple=res_samples, 554 | encoder_hidden_states=encoder_hidden_states, 555 | upsample_size=upsample_size, 556 | attention_mask=attention_mask, 557 | audio_embedding=audio_embedding, 558 | emotion=audio_emotion, 559 | 
uc_mask=uc_mask, 560 | is_new_audio=is_new_audio, 561 | update_past_memory=update_past_memory, 562 | ) 563 | else: 564 | sample = upsample_block( 565 | hidden_states=sample, 566 | ref_feature_list=ref_features["up"][i], 567 | temb=emb, 568 | res_hidden_states_tuple=res_samples, 569 | upsample_size=upsample_size, 570 | encoder_hidden_states=encoder_hidden_states, 571 | is_new_audio=is_new_audio, 572 | update_past_memory=update_past_memory, 573 | ) 574 | 575 | # post-process 576 | sample = self.conv_norm_out(sample) 577 | sample = self.conv_act(sample) 578 | sample = self.conv_out(sample) 579 | 580 | if not return_dict: 581 | return (sample,) 582 | 583 | return UNet3DConditionOutput(sample=sample) 584 | -------------------------------------------------------------------------------- /memo/models/wav2vec.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from transformers import Wav2Vec2Model 3 | from transformers.modeling_outputs import BaseModelOutput 4 | 5 | 6 | class Wav2VecModel(Wav2Vec2Model): 7 | def forward( 8 | self, 9 | input_values, 10 | seq_len, 11 | attention_mask=None, 12 | mask_time_indices=None, 13 | output_attentions=None, 14 | output_hidden_states=None, 15 | return_dict=None, 16 | ): 17 | self.config.output_attentions = True 18 | 19 | output_hidden_states = ( 20 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 21 | ) 22 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 23 | 24 | extract_features = self.feature_extractor(input_values) 25 | extract_features = extract_features.transpose(1, 2) 26 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 27 | 28 | if attention_mask is not None: 29 | # compute reduced attention_mask corresponding to feature vectors 30 | attention_mask = self._get_feature_vector_attention_mask( 31 | extract_features.shape[1], attention_mask, add_adapter=False 32 | ) 33 | 34 | hidden_states, extract_features = self.feature_projection(extract_features) 35 | hidden_states = self._mask_hidden_states( 36 | hidden_states, 37 | mask_time_indices=mask_time_indices, 38 | attention_mask=attention_mask, 39 | ) 40 | 41 | encoder_outputs = self.encoder( 42 | hidden_states, 43 | attention_mask=attention_mask, 44 | output_attentions=output_attentions, 45 | output_hidden_states=output_hidden_states, 46 | return_dict=return_dict, 47 | ) 48 | 49 | hidden_states = encoder_outputs[0] 50 | 51 | if self.adapter is not None: 52 | hidden_states = self.adapter(hidden_states) 53 | 54 | if not return_dict: 55 | return (hidden_states,) + encoder_outputs[1:] 56 | return BaseModelOutput( 57 | last_hidden_state=hidden_states, 58 | hidden_states=encoder_outputs.hidden_states, 59 | attentions=encoder_outputs.attentions, 60 | ) 61 | 62 | def feature_extract( 63 | self, 64 | input_values, 65 | seq_len, 66 | ): 67 | extract_features = self.feature_extractor(input_values) 68 | extract_features = extract_features.transpose(1, 2) 69 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 70 | 71 | return extract_features 72 | 73 | def encode( 74 | self, 75 | extract_features, 76 | attention_mask=None, 77 | mask_time_indices=None, 78 | output_attentions=None, 79 | output_hidden_states=None, 80 | return_dict=None, 81 | ): 82 | self.config.output_attentions = True 83 | 84 | output_hidden_states = ( 85 | output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states 86 | ) 87 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 88 | 89 | if attention_mask is not None: 90 | # compute reduced attention_mask corresponding to feature vectors 91 | attention_mask = self._get_feature_vector_attention_mask( 92 | extract_features.shape[1], attention_mask, add_adapter=False 93 | ) 94 | 95 | hidden_states, extract_features = self.feature_projection(extract_features) 96 | hidden_states = self._mask_hidden_states( 97 | hidden_states, 98 | mask_time_indices=mask_time_indices, 99 | attention_mask=attention_mask, 100 | ) 101 | 102 | encoder_outputs = self.encoder( 103 | hidden_states, 104 | attention_mask=attention_mask, 105 | output_attentions=output_attentions, 106 | output_hidden_states=output_hidden_states, 107 | return_dict=return_dict, 108 | ) 109 | 110 | hidden_states = encoder_outputs[0] 111 | 112 | if self.adapter is not None: 113 | hidden_states = self.adapter(hidden_states) 114 | 115 | if not return_dict: 116 | return (hidden_states,) + encoder_outputs[1:] 117 | return BaseModelOutput( 118 | last_hidden_state=hidden_states, 119 | hidden_states=encoder_outputs.hidden_states, 120 | attentions=encoder_outputs.attentions, 121 | ) 122 | 123 | 124 | def linear_interpolation(features, seq_len): 125 | features = features.transpose(1, 2) 126 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode="linear") 127 | return output_features.transpose(1, 2) 128 | -------------------------------------------------------------------------------- /memo/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/pipelines/__init__.py -------------------------------------------------------------------------------- /memo/pipelines/video_pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | from typing import Callable, List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import ( 8 | DDIMScheduler, 9 | DiffusionPipeline, 10 | DPMSolverMultistepScheduler, 11 | EulerAncestralDiscreteScheduler, 12 | EulerDiscreteScheduler, 13 | LMSDiscreteScheduler, 14 | PNDMScheduler, 15 | ) 16 | from diffusers.image_processor import VaeImageProcessor 17 | from diffusers.utils import BaseOutput 18 | from diffusers.utils.torch_utils import randn_tensor 19 | from einops import rearrange 20 | 21 | 22 | @dataclass 23 | class VideoPipelineOutput(BaseOutput): 24 | videos: Union[torch.Tensor, np.ndarray] 25 | 26 | 27 | class VideoPipeline(DiffusionPipeline): 28 | def __init__( 29 | self, 30 | vae, 31 | reference_net, 32 | diffusion_net, 33 | image_proj, 34 | scheduler: Union[ 35 | DDIMScheduler, 36 | PNDMScheduler, 37 | LMSDiscreteScheduler, 38 | EulerDiscreteScheduler, 39 | EulerAncestralDiscreteScheduler, 40 | DPMSolverMultistepScheduler, 41 | ], 42 | ) -> None: 43 | super().__init__() 44 | 45 | self.register_modules( 46 | vae=vae, 47 | reference_net=reference_net, 48 | diffusion_net=diffusion_net, 49 | scheduler=scheduler, 50 | image_proj=image_proj, 51 | ) 52 | 53 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1) 54 | 55 | self.ref_image_processor = VaeImageProcessor( 56 | vae_scale_factor=self.vae_scale_factor, 57 | do_convert_rgb=True, 58 | ) 59 | 60 | @property 61 | def _execution_device(self): 
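# NOTE (editorial comment, not in the upstream source): mirrors the diffusers pattern --
# if the pipeline is not offloaded (device != "meta"), simply report `self.device`;
# otherwise fall back to the accelerate hook's execution_device. The fallback still refers
# to `self.unet`, although this pipeline registers `reference_net`/`diffusion_net` instead,
# so that branch is only reachable (and would need adjusting) when weights are offloaded.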
62 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 63 | return self.device 64 | for module in self.unet.modules(): 65 | if ( 66 | hasattr(module, "_hf_hook") 67 | and hasattr(module._hf_hook, "execution_device") 68 | and module._hf_hook.execution_device is not None 69 | ): 70 | return torch.device(module._hf_hook.execution_device) 71 | return self.device 72 | 73 | def prepare_latents( 74 | self, 75 | batch_size: int, # Number of videos to generate in parallel 76 | num_channels_latents: int, # Number of channels in the latents 77 | width: int, # Width of the video frame 78 | height: int, # Height of the video frame 79 | video_length: int, # Length of the video in frames 80 | dtype: torch.dtype, # Data type of the latents 81 | device: torch.device, # Device to store the latents on 82 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility 83 | latents: Optional[torch.Tensor] = None, # Pre-generated latents (optional) 84 | ): 85 | shape = ( 86 | batch_size, 87 | num_channels_latents, 88 | video_length, 89 | height // self.vae_scale_factor, 90 | width // self.vae_scale_factor, 91 | ) 92 | if isinstance(generator, list) and len(generator) != batch_size: 93 | raise ValueError( 94 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 95 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 96 | ) 97 | 98 | if latents is None: 99 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 100 | else: 101 | latents = latents.to(device) 102 | 103 | # scale the initial noise by the standard deviation required by the scheduler 104 | if hasattr(self.scheduler, "init_noise_sigma"): 105 | latents = latents * self.scheduler.init_noise_sigma 106 | return latents 107 | 108 | def prepare_extra_step_kwargs(self, generator, eta): 109 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 110 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
111 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 112 | # and should be between [0, 1] 113 | 114 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 115 | extra_step_kwargs = {} 116 | if accepts_eta: 117 | extra_step_kwargs["eta"] = eta 118 | 119 | # check if the scheduler accepts generator 120 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 121 | if accepts_generator: 122 | extra_step_kwargs["generator"] = generator 123 | return extra_step_kwargs 124 | 125 | def decode_latents(self, latents): 126 | video_length = latents.shape[2] 127 | latents = 1 / 0.18215 * latents 128 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 129 | video = [] 130 | for frame_idx in range(latents.shape[0]): 131 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample) 132 | video = torch.cat(video) 133 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 134 | video = (video / 2 + 0.5).clamp(0, 1) 135 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 136 | video = video.cpu().float().numpy() 137 | return video 138 | 139 | @torch.no_grad() 140 | def __call__( 141 | self, 142 | ref_image, 143 | face_emb, 144 | audio_tensor, 145 | width, 146 | height, 147 | video_length, 148 | num_inference_steps, 149 | guidance_scale, 150 | num_images_per_prompt=1, 151 | eta: float = 0.0, 152 | audio_emotion=None, 153 | emotion_class_num=None, 154 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 155 | output_type: Optional[str] = "tensor", 156 | return_dict: bool = True, 157 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 158 | callback_steps: Optional[int] = 1, 159 | ): 160 | # Default height and width to unet 161 | height = height or self.unet.config.sample_size * self.vae_scale_factor 162 | width = width or self.unet.config.sample_size * self.vae_scale_factor 163 | 164 | device = self._execution_device 165 | 166 | do_classifier_free_guidance = guidance_scale > 1.0 167 | 168 | # Prepare timesteps 169 | self.scheduler.set_timesteps(num_inference_steps, device=device) 170 | timesteps = self.scheduler.timesteps 171 | 172 | batch_size = 1 173 | 174 | # prepare clip image embeddings 175 | clip_image_embeds = face_emb 176 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype) 177 | 178 | encoder_hidden_states = self.image_proj(clip_image_embeds) 179 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds)) 180 | 181 | if do_classifier_free_guidance: 182 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0) 183 | 184 | num_channels_latents = self.diffusion_net.in_channels 185 | 186 | latents = self.prepare_latents( 187 | batch_size * num_images_per_prompt, 188 | num_channels_latents, 189 | width, 190 | height, 191 | video_length, 192 | clip_image_embeds.dtype, 193 | device, 194 | generator, 195 | ) 196 | 197 | # Prepare extra step kwargs. 
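# NOTE (editorial comment, not in the upstream source): only keyword arguments that the
# scheduler's `step()` signature actually accepts are forwarded, e.g. DDIMScheduler.step
# takes both `eta` and `generator`, while EulerDiscreteScheduler.step takes a `generator`
# but no `eta`, so the resulting dict may be empty.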
198 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 199 | 200 | # Prepare ref image latents 201 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w") 202 | ref_image_tensor = self.ref_image_processor.preprocess( 203 | ref_image_tensor, height=height, width=width 204 | ) # (bs, c, width, height) 205 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device) 206 | # To save memory on GPUs like RTX 4090, we encode each frame separately 207 | # ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 208 | ref_image_latents = [] 209 | for frame_idx in range(ref_image_tensor.shape[0]): 210 | ref_image_latents.append(self.vae.encode(ref_image_tensor[frame_idx : frame_idx + 1]).latent_dist.mean) 211 | ref_image_latents = torch.cat(ref_image_latents, dim=0) 212 | 213 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 214 | 215 | if do_classifier_free_guidance: 216 | uncond_audio_tensor = torch.zeros_like(audio_tensor) 217 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0) 218 | audio_tensor = audio_tensor.to(dtype=self.diffusion_net.dtype, device=self.diffusion_net.device) 219 | 220 | # denoising loop 221 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 222 | with self.progress_bar(total=num_inference_steps) as progress_bar: 223 | for i in range(len(timesteps)): 224 | t = timesteps[i] 225 | # Forward reference image 226 | if i == 0: 227 | ref_features = self.reference_net( 228 | ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1), 229 | torch.zeros_like(t), 230 | encoder_hidden_states=encoder_hidden_states, 231 | return_dict=False, 232 | ) 233 | 234 | # expand the latents if we are doing classifier free guidance 235 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 236 | if hasattr(self.scheduler, "scale_model_input"): 237 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 238 | 239 | audio_emotion = torch.tensor(torch.mode(audio_emotion).values.item()).to( 240 | dtype=torch.int, device=self.diffusion_net.device 241 | ) 242 | if do_classifier_free_guidance: 243 | uncond_audio_emotion = torch.full_like(audio_emotion, emotion_class_num) 244 | audio_emotion = torch.cat( 245 | [uncond_audio_emotion.unsqueeze(0), audio_emotion.unsqueeze(0)], 246 | dim=0, 247 | ) 248 | 249 | uc_mask = ( 250 | torch.Tensor( 251 | [1] * batch_size * num_images_per_prompt * 16 252 | + [0] * batch_size * num_images_per_prompt * 16 253 | ) 254 | .to(device) 255 | .bool() 256 | ) 257 | else: 258 | uc_mask = None 259 | 260 | noise_pred = self.diffusion_net( 261 | latent_model_input, 262 | ref_features, 263 | t, 264 | encoder_hidden_states=encoder_hidden_states, 265 | audio_embedding=audio_tensor, 266 | audio_emotion=audio_emotion, 267 | uc_mask=uc_mask, 268 | ).sample 269 | 270 | # perform guidance 271 | if do_classifier_free_guidance: 272 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 273 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 274 | 275 | # compute the previous noisy sample x_t -> x_t-1 276 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 277 | 278 | # call the callback, if provided 279 | if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: 280 | progress_bar.update() 281 | if callback is not None and i % callback_steps == 0: 282 | 
step_idx = i // getattr(self.scheduler, "order", 1) 283 | callback(step_idx, t, latents) 284 | 285 | # Post-processing 286 | images = self.decode_latents(latents) # (b, c, f, h, w) 287 | 288 | # Convert to tensor 289 | if output_type == "tensor": 290 | images = torch.from_numpy(images) 291 | 292 | if not return_dict: 293 | return images 294 | 295 | return VideoPipelineOutput(videos=images) 296 | -------------------------------------------------------------------------------- /memo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/utils/__init__.py -------------------------------------------------------------------------------- /memo/utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import subprocess 5 | from io import BytesIO 6 | 7 | import librosa 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torchaudio 12 | from audio_separator.separator import Separator 13 | from einops import rearrange 14 | from funasr.download.download_from_hub import download_model 15 | from funasr.models.emotion2vec.model import Emotion2vec 16 | from transformers import Wav2Vec2FeatureExtractor 17 | from safetensors.torch import load_file 18 | 19 | from memo.models.emotion_classifier import AudioEmotionClassifierModel 20 | from memo.models.wav2vec import Wav2VecModel 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int = 16000): 27 | p = subprocess.Popen( 28 | [ 29 | "ffmpeg", 30 | "-y", 31 | "-v", 32 | "error", 33 | "-i", 34 | input_audio_file, 35 | "-ar", 36 | str(sample_rate), 37 | output_audio_file, 38 | ] 39 | ) 40 | ret = p.wait() 41 | assert ret == 0, f"Resample audio failed! Input: {input_audio_file}, Output: {output_audio_file}" 42 | return output_audio_file 43 | 44 | 45 | @torch.no_grad() 46 | def preprocess_audio( 47 | wav_path: str, 48 | fps: int, 49 | wav2vec_model: str, 50 | vocal_separator_model: str = None, 51 | cache_dir: str = "", 52 | device: str = "cuda", 53 | sample_rate: int = 16000, 54 | num_generated_frames_per_clip: int = -1, 55 | ): 56 | """ 57 | Preprocess the audio file and extract audio embeddings. 58 | 59 | Args: 60 | wav_path (str): Path to the input audio file. 61 | fps (int): Frames per second for the audio processing. 62 | wav2vec_model (str): Path to the pretrained Wav2Vec model. 63 | vocal_separator_model (str, optional): Path to the vocal separator model. Defaults to None. 64 | cache_dir (str, optional): Directory for cached files. Defaults to "". 65 | device (str, optional): Device to use ('cuda' or 'cpu'). Defaults to "cuda". 66 | sample_rate (int, optional): Sampling rate for audio processing. Defaults to 16000. 67 | num_generated_frames_per_clip (int, optional): Number of generated frames per clip for padding. Defaults to -1. 68 | 69 | Returns: 70 | tuple: A tuple containing: 71 | - audio_emb (torch.Tensor): The processed audio embeddings. 72 | - audio_length (int): The length of the audio in frames. 
73 | """ 74 | # Initialize Wav2Vec model 75 | audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model).to(device=device) 76 | audio_encoder.feature_extractor._freeze_parameters() 77 | 78 | # Initialize Wav2Vec feature extractor 79 | wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model) 80 | 81 | # Initialize vocal separator if provided 82 | vocal_separator = None 83 | if vocal_separator_model is not None: 84 | os.makedirs(cache_dir, exist_ok=True) 85 | vocal_separator = Separator( 86 | output_dir=cache_dir, 87 | output_single_stem="vocals", 88 | model_file_dir=os.path.dirname(vocal_separator_model), 89 | ) 90 | vocal_separator.load_model(os.path.basename(vocal_separator_model)) 91 | assert vocal_separator.model_instance is not None, "Failed to load audio separation model." 92 | 93 | # Perform vocal separation if applicable 94 | if vocal_separator is not None: 95 | outputs = vocal_separator.separate(wav_path) 96 | assert len(outputs) > 0, "Audio separation failed." 97 | vocal_audio_file = outputs[0] 98 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file) 99 | vocal_audio_file = os.path.join(vocal_separator.output_dir, vocal_audio_file) 100 | vocal_audio_file = resample_audio( 101 | vocal_audio_file, 102 | os.path.join(vocal_separator.output_dir, f"{vocal_audio_name}-16k.wav"), 103 | sample_rate, 104 | ) 105 | else: 106 | vocal_audio_file = wav_path 107 | 108 | # Load audio and extract Wav2Vec features 109 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=sample_rate) 110 | audio_feature = np.squeeze(wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) 111 | audio_length = math.ceil(len(audio_feature) / sample_rate * fps) 112 | audio_feature = torch.from_numpy(audio_feature).float().to(device=device) 113 | 114 | # Pad audio features to match the required length 115 | if num_generated_frames_per_clip > 0 and audio_length % num_generated_frames_per_clip != 0: 116 | audio_feature = torch.nn.functional.pad( 117 | audio_feature, 118 | ( 119 | 0, 120 | (num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip) * (sample_rate // fps), 121 | ), 122 | "constant", 123 | 0.0, 124 | ) 125 | audio_length += num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip 126 | audio_feature = audio_feature.unsqueeze(0) 127 | 128 | # Extract audio embeddings 129 | with torch.no_grad(): 130 | embeddings = audio_encoder(audio_feature, seq_len=audio_length, output_hidden_states=True) 131 | assert len(embeddings) > 0, "Failed to extract audio embeddings." 
132 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) 133 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 134 | 135 | # Concatenate embeddings with surrounding frames 136 | audio_emb = audio_emb.cpu().detach() 137 | concatenated_tensors = [] 138 | for i in range(audio_emb.shape[0]): 139 | vectors_to_concat = [audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)] 140 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) 141 | audio_emb = torch.stack(concatenated_tensors, dim=0) 142 | 143 | if vocal_separator is not None: 144 | del vocal_separator 145 | del audio_encoder 146 | 147 | return audio_emb, audio_length 148 | 149 | 150 | @torch.no_grad() 151 | def extract_audio_emotion_labels( 152 | model: str, 153 | wav_path: str, 154 | emotion2vec_model: str, 155 | audio_length: int, 156 | sample_rate: int = 16000, 157 | device: str = "cuda", 158 | ): 159 | """ 160 | Extract audio emotion labels from an audio file. 161 | """ 162 | # Load models 163 | logger.info("Downloading emotion2vec models from modelscope") 164 | kwargs = download_model(model=emotion2vec_model) 165 | kwargs["tokenizer"] = None 166 | kwargs["input_size"] = None 167 | kwargs["frontend"] = None 168 | emotion_model = Emotion2vec(**kwargs, vocab_size=-1).to(device) 169 | init_param = kwargs.get("init_param", None) 170 | load_emotion2vec_model( 171 | model=emotion_model, 172 | path=init_param, 173 | ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), 174 | oss_bucket=kwargs.get("oss_bucket", None), 175 | scope_map=kwargs.get("scope_map", []), 176 | ) 177 | emotion_model.eval() 178 | 179 | # Create and load emotion classifier directly 180 | classifier = AudioEmotionClassifierModel().to(device=device) 181 | emotion_classifier_path = os.path.join( 182 | model, 183 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors" 184 | ) 185 | classifier.load_state_dict(load_file(emotion_classifier_path)) 186 | classifier.eval() 187 | 188 | # Load audio 189 | wav, sr = torchaudio.load(wav_path) 190 | if sr != sample_rate: 191 | wav = torchaudio.functional.resample(wav, sr, sample_rate) 192 | wav = wav.view(-1) if wav.dim() == 1 else wav[0].view(-1) 193 | 194 | emotion_labels = torch.full_like(wav, -1, dtype=torch.int32) 195 | 196 | def extract_emotion(x): 197 | """ 198 | Extract emotion for a given audio segment. 
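        Runs emotion2vec feature extraction on the segment, averages the
        features over time, applies the emotion classifier, and returns the
        argmax class index as a tensor.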
199 | """ 200 | x = x.to(device=device) 201 | x = F.layer_norm(x, x.shape).view(1, -1) 202 | feats = emotion_model.extract_features(x) 203 | x = feats["x"].mean(dim=1) # average across frames 204 | x = classifier(x) 205 | x = torch.softmax(x, dim=-1) 206 | return torch.argmax(x, dim=-1) 207 | 208 | # Process start, middle, and end segments 209 | start_label = extract_emotion(wav[: sample_rate * 2]).item() 210 | emotion_labels[:sample_rate] = start_label 211 | 212 | for i in range(sample_rate, len(wav) - sample_rate, sample_rate): 213 | mid_wav = wav[i - sample_rate : i - sample_rate + sample_rate * 3] 214 | mid_label = extract_emotion(mid_wav).item() 215 | emotion_labels[i : i + sample_rate] = mid_label 216 | 217 | end_label = extract_emotion(wav[-sample_rate * 2 :]).item() 218 | emotion_labels[-sample_rate:] = end_label 219 | 220 | # Interpolate to match the target audio length 221 | emotion_labels = emotion_labels.unsqueeze(0).unsqueeze(0).float() 222 | emotion_labels = F.interpolate(emotion_labels, size=audio_length, mode="nearest").squeeze(0).squeeze(0).int() 223 | num_emotion_classes = classifier.num_emotion_classes 224 | 225 | del emotion_model 226 | del classifier 227 | 228 | return emotion_labels, num_emotion_classes 229 | 230 | 231 | def load_emotion2vec_model( 232 | path: str, 233 | model: torch.nn.Module, 234 | ignore_init_mismatch: bool = True, 235 | map_location: str = "cpu", 236 | oss_bucket=None, 237 | scope_map=[], 238 | ): 239 | obj = model 240 | dst_state = obj.state_dict() 241 | logger.debug(f"Emotion2vec checkpoint: {path}") 242 | if oss_bucket is None: 243 | src_state = torch.load(path, map_location=map_location) 244 | else: 245 | buffer = BytesIO(oss_bucket.get_object(path).read()) 246 | src_state = torch.load(buffer, map_location=map_location) 247 | 248 | src_state = src_state["state_dict"] if "state_dict" in src_state else src_state 249 | src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state 250 | src_state = src_state["model"] if "model" in src_state else src_state 251 | 252 | if isinstance(scope_map, str): 253 | scope_map = scope_map.split(",") 254 | scope_map += ["module.", "None"] 255 | 256 | for k in dst_state.keys(): 257 | k_src = k 258 | if scope_map is not None: 259 | src_prefix = "" 260 | dst_prefix = "" 261 | for i in range(0, len(scope_map), 2): 262 | src_prefix = scope_map[i] if scope_map[i].lower() != "none" else "" 263 | dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else "" 264 | 265 | if dst_prefix == "" and (src_prefix + k) in src_state.keys(): 266 | k_src = src_prefix + k 267 | if not k_src.startswith("module."): 268 | logger.debug(f"init param, map: {k} from {k_src} in ckpt") 269 | elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys(): 270 | k_src = k.replace(dst_prefix, src_prefix, 1) 271 | if not k_src.startswith("module."): 272 | logger.debug(f"init param, map: {k} from {k_src} in ckpt") 273 | 274 | if k_src in src_state.keys(): 275 | if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape: 276 | logger.debug( 277 | f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}" 278 | ) 279 | else: 280 | dst_state[k] = src_state[k_src] 281 | 282 | else: 283 | logger.debug(f"Warning, miss key in ckpt: {k}, mapped: {k_src}") 284 | 285 | obj.load_state_dict(dst_state, strict=True) 286 | -------------------------------------------------------------------------------- 
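Taken together, the helpers in audio_utils.py cover the audio side of the pipeline: resampling to 16 kHz, optional vocal separation, windowed Wav2Vec embedding extraction, and per-frame emotion labelling. The sketch below strings them together outside of ComfyUI. It is illustrative only; the directory names mirror the layout that memo_model_manager.py sets up and are assumptions, not values discovered at runtime.

import os

from memo.utils.audio_utils import (
    extract_audio_emotion_labels,
    preprocess_audio,
    resample_audio,
)

# Assumed locations; adjust to your ComfyUI install. These mirror the layout
# created by MemoModelManager and are not resolved automatically here.
models_dir = "ComfyUI/models"
memo_base = os.path.join(models_dir, "checkpoints", "memo")
wav2vec_dir = os.path.join(models_dir, "wav2vec", "facebook", "wav2vec2-base-960h")
emotion2vec_dir = os.path.join(models_dir, "emotion2vec", "iic", "emotion2vec_plus_large")
cache_dir = "memo_audio_cache"
os.makedirs(cache_dir, exist_ok=True)

# 1) Resample the input speech to 16 kHz (ffmpeg must be on PATH).
wav_16k = resample_audio("path/to/speech.wav", os.path.join(cache_dir, "speech-16k.wav"))

# 2) Extract windowed Wav2Vec embeddings, separating vocals first if a model is given.
audio_emb, audio_length = preprocess_audio(
    wav_path=wav_16k,
    fps=30,
    wav2vec_model=wav2vec_dir,
    vocal_separator_model=os.path.join(memo_base, "misc", "vocal_separator", "Kim_Vocal_2.onnx"),
    cache_dir=cache_dir,
    num_generated_frames_per_clip=16,
)

# 3) Derive one emotion label per generated frame.
emotion_labels, num_emotion_classes = extract_audio_emotion_labels(
    model=memo_base,                 # directory containing misc/audio_emotion_classifier
    wav_path=wav_16k,
    emotion2vec_model=emotion2vec_dir,  # local snapshot or ModelScope id
    audio_length=audio_length,
)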
/memo/utils/vision_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from insightface.app import FaceAnalysis 8 | from moviepy.editor import AudioFileClip, VideoClip 9 | from PIL import Image 10 | from torchvision import transforms 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def tensor_to_video(tensor, output_video_path, input_audio_path, fps=30): 17 | """ 18 | Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file. 19 | 20 | Args: 21 | tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w]. 22 | output_video_path (str): The file path where the output video will be saved. 23 | input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added. 24 | fps (int): The frame rate of the output video. Default is 30 fps. 25 | """ 26 | tensor = tensor.permute(1, 2, 3, 0).cpu().numpy() # convert to [f, h, w, c] 27 | tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8) # to [0, 255] 28 | 29 | def make_frame(t): 30 | frame_index = min(int(t * fps), tensor.shape[0] - 1) 31 | return tensor[frame_index] 32 | 33 | video_duration = tensor.shape[0] / fps 34 | audio_clip = AudioFileClip(input_audio_path) 35 | audio_duration = audio_clip.duration 36 | final_duration = min(video_duration, audio_duration) 37 | audio_clip = audio_clip.subclip(0, final_duration) 38 | new_video_clip = VideoClip(make_frame, duration=final_duration) 39 | new_video_clip = new_video_clip.set_audio(audio_clip) 40 | new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac") 41 | 42 | 43 | @torch.no_grad() 44 | def preprocess_image(face_analysis_model, image_path, image_size): 45 | """Preprocess image for MEMO pipeline""" 46 | # Modify face analysis initialization 47 | face_analysis = FaceAnalysis( 48 | name="", 49 | root=face_analysis_model, # Use parent directory 50 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] 51 | ) 52 | face_analysis.prepare(ctx_id=0, det_size=(640, 640)) 53 | 54 | # Define the image transformation 55 | transform = transforms.Compose( 56 | [ 57 | transforms.Resize((image_size, image_size)), 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.5], [0.5]), 60 | ] 61 | ) 62 | 63 | # Load and preprocess the image 64 | image = Image.open(image_path).convert("RGB") 65 | pixel_values = transform(image) 66 | pixel_values = pixel_values.unsqueeze(0) 67 | 68 | # Detect faces and extract the face embedding 69 | image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 70 | faces = face_analysis.get(image_bgr) 71 | if not faces: 72 | logger.warning("No faces detected in the image. Using a zero vector as the face embedding.") 73 | face_emb = np.zeros(512) 74 | else: 75 | # Sort faces by size and select the largest one 76 | faces_sorted = sorted( 77 | faces, 78 | key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), 79 | reverse=True, 80 | ) 81 | if "embedding" not in faces_sorted[0]: 82 | logger.warning("The detected face does not have an 'embedding'. 
Using a zero vector.") 83 | face_emb = np.zeros(512) 84 | else: 85 | face_emb = faces_sorted[0]["embedding"] 86 | 87 | # Convert face embedding to a PyTorch tensor 88 | face_emb = face_emb.reshape(1, -1) 89 | face_emb = torch.tensor(face_emb) 90 | 91 | del face_analysis 92 | 93 | return pixel_values, face_emb 94 | -------------------------------------------------------------------------------- /memo_model_manager.py: -------------------------------------------------------------------------------- 1 | #memo_model_manager.py 2 | import os 3 | import logging 4 | import json 5 | import shutil 6 | from huggingface_hub import hf_hub_download 7 | from modelscope import snapshot_download 8 | import folder_paths 9 | 10 | logger = logging.getLogger("memo") 11 | 12 | class MemoModelManager: 13 | def __init__(self): 14 | self.models_base = folder_paths.models_dir 15 | self._setup_paths() 16 | self._ensure_model_structure() 17 | 18 | def _setup_paths(self): 19 | """Initialize base paths structure""" 20 | self.paths = { 21 | "memo_base": os.path.join(self.models_base, "checkpoints", "memo"), 22 | "wav2vec": os.path.join(self.models_base, "wav2vec", "facebook", "wav2vec2-base-960h"), 23 | "emotion2vec": os.path.join(self.models_base, "emotion2vec", "iic", "emotion2vec_plus_large"), 24 | "vae": os.path.join(self.models_base, "vae", "stabilityai", "sd-vae-ft-mse") 25 | } 26 | 27 | # Create directories 28 | for path in self.paths.values(): 29 | os.makedirs(path, exist_ok=True) 30 | 31 | # Create memo subfolders 32 | for subdir in ["reference_net", "diffusion_net", "image_proj", "audio_proj", 33 | "misc/audio_emotion_classifier", "misc/face_analysis", "misc/vocal_separator"]: 34 | os.makedirs(os.path.join(self.paths["memo_base"], subdir), exist_ok=True) 35 | 36 | def _direct_download(self, repo_id, filename, target_path, force=False): 37 | """Download directly to target path without extra nesting""" 38 | try: 39 | if not force and os.path.exists(target_path): 40 | return target_path 41 | 42 | download_path = hf_hub_download( 43 | repo_id=repo_id, 44 | filename=filename, 45 | local_dir=os.path.dirname(target_path), 46 | local_dir_use_symlinks=False 47 | ) 48 | 49 | # Move if downloaded to wrong location 50 | if download_path != target_path: 51 | shutil.move(download_path, target_path) 52 | 53 | return target_path 54 | except Exception as e: 55 | logger.warning(f"Failed to download {filename} from {repo_id} to {target_path}: {e}") 56 | return None 57 | 58 | def _setup_face_analysis(self): 59 | """Setup face analysis models with correct structure""" 60 | face_dir = os.path.join(self.paths["memo_base"], "misc", "face_analysis") 61 | models_dir = os.path.join(face_dir, "models") # Create a models subdirectory 62 | os.makedirs(models_dir, exist_ok=True) 63 | 64 | # Create models.json 65 | models_json = { 66 | "detection": ["scrfd_10g_bnkps"], 67 | "recognition": ["glintr100"], 68 | "analysis": ["genderage", "2d106det", "1k3d68"] 69 | } 70 | 71 | # Write models.json 72 | models_json_path = os.path.join(face_dir, "models.json") 73 | with open(models_json_path, "w") as f: 74 | json.dump(models_json, f, indent=2) 75 | 76 | # Create version.txt 77 | version_path = os.path.join(face_dir, "version.txt") 78 | with open(version_path, "w") as f: 79 | f.write("0.7.3") 80 | 81 | # Download model files if they don't exist 82 | required_models = { 83 | "scrfd_10g_bnkps.onnx": "scrfd_10g_bnkps", 84 | "glintr100.onnx": "glintr100", 85 | "genderage.onnx": "genderage", 86 | "2d106det.onnx": "2d106det", 87 | "1k3d68.onnx": "1k3d68", 
88 | "face_landmarker_v2_with_blendshapes.task": "face_landmarker" 89 | } 90 | 91 | for model_file, model_name in required_models.items(): 92 | target_path = os.path.join(models_dir, model_file) # Save in models subdirectory 93 | if not os.path.exists(target_path): 94 | self._direct_download( 95 | "memoavatar/memo", 96 | f"misc/face_analysis/models/{model_file}", 97 | target_path 98 | ) 99 | # Create symlink in parent directory for compatibility 100 | parent_target = os.path.join(face_dir, model_file) 101 | if not os.path.exists(parent_target): 102 | if os.name == 'nt': # Windows 103 | import shutil 104 | shutil.copy2(target_path, parent_target) 105 | else: # Unix-like 106 | os.symlink(target_path, parent_target) 107 | 108 | # Set environment variable for face models 109 | os.environ["MEMO_FACE_MODELS"] = face_dir 110 | return face_dir 111 | 112 | def _ensure_model_structure(self): 113 | """Download all required models to correct locations""" 114 | # Set up face analysis and environment variables first 115 | face_dir = self._setup_face_analysis() 116 | os.environ["MEMO_FACE_MODELS"] = face_dir 117 | os.environ["MEMO_VOCAL_MODEL"] = os.path.join( 118 | self.paths["memo_base"], "misc/vocal_separator/Kim_Vocal_2.onnx" 119 | ) 120 | 121 | # Download memo components 122 | components = { 123 | "reference_net": ["config.json", "diffusion_pytorch_model.safetensors"], 124 | "diffusion_net": ["config.json", "diffusion_pytorch_model.safetensors"], 125 | "image_proj": ["config.json", "diffusion_pytorch_model.safetensors"], 126 | "audio_proj": ["config.json", "diffusion_pytorch_model.safetensors"] 127 | } 128 | 129 | for component, files in components.items(): 130 | component_dir = os.path.join(self.paths["memo_base"], component) 131 | for file in files: 132 | self._direct_download( 133 | "memoavatar/memo", 134 | f"{component}/{file}", 135 | os.path.join(component_dir, file) 136 | ) 137 | 138 | # Download vocal separator 139 | self._direct_download( 140 | "memoavatar/memo", 141 | "misc/vocal_separator/Kim_Vocal_2.onnx", 142 | os.path.join(self.paths["memo_base"], "misc/vocal_separator/Kim_Vocal_2.onnx") 143 | ) 144 | 145 | # Download emotion classifier 146 | self._direct_download( 147 | "memoavatar/memo", 148 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors", 149 | os.path.join(self.paths["memo_base"], "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors") 150 | ) 151 | 152 | # Download wav2vec files 153 | for file in ["config.json", "preprocessor_config.json", "pytorch_model.bin", 154 | "special_tokens_map.json", "tokenizer_config.json", "vocab.json"]: 155 | self._direct_download( 156 | "facebook/wav2vec2-base-960h", 157 | file, 158 | os.path.join(self.paths["wav2vec"], file) 159 | ) 160 | 161 | # Download emotion2vec 162 | try: 163 | snapshot_download( 164 | "iic/emotion2vec_plus_large", 165 | local_dir=self.paths["emotion2vec"] 166 | ) 167 | except Exception as e: 168 | logger.warning(f"Failed to download emotion2vec model: {e}") 169 | 170 | # Download VAE 171 | for file in ["config.json", "diffusion_pytorch_model.safetensors"]: 172 | self._direct_download( 173 | "stabilityai/sd-vae-ft-mse", 174 | file, 175 | os.path.join(self.paths["vae"], file) 176 | ) 177 | 178 | def get_model_paths(self): 179 | """Return paths dictionary""" 180 | return { 181 | "memo_base": self.paths["memo_base"], 182 | "face_models": os.path.join(self.paths["memo_base"], "misc/face_analysis"), 183 | "vocal_separator": os.path.join(self.paths["memo_base"], 
"misc/vocal_separator/Kim_Vocal_2.onnx"), 184 | "wav2vec": self.paths["wav2vec"], 185 | "emotion2vec": self.paths["emotion2vec"], 186 | "vae": self.paths["vae"] 187 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_if_memoavatar" 3 | description = "Talking avatars MemoAvatar Memory-Guided Diffusion for Expressive Talking Video Generation" 4 | version = "0.0.3" 5 | license = { file = "MIT License" } 6 | dependencies = [ 7 | "albumentations==1.4.21", 8 | "modelscope==1.20.1", 9 | "numba==0.60.0", 10 | "librosa==0.10.2", 11 | "diffusers>=0.31.0", 12 | "transformers>=4.46.3", 13 | "numpy>=1.26.4", 14 | "PyYAML>=6.0.1", 15 | "moviepy==1.0.3", 16 | "pillow>=10.4.0", 17 | "funasr==1.0.27", 18 | "insightface==0.7.3", 19 | "accelerate==1.1.1", 20 | "black==23.12.1", 21 | "ffmpeg-python==0.2.0", 22 | "huggingface-hub==0.26.2", 23 | "imageio==2.36.0", 24 | "imageio-ffmpeg==0.5.1", 25 | "hydra-core==1.3.2", 26 | "jax==0.4.35", 27 | "mediapipe==0.10.18", 28 | "omegaconf==2.3.0", 29 | "onnxruntime-gpu>=1.20.1", 30 | "opencv-python-headless==4.10.0.84", 31 | "scikit-learn>=1.5.2", 32 | "scipy>=1.14.1", 33 | "tqdm>=4.67.1", 34 | ] 35 | 36 | [project.urls] 37 | Repository = "https://github.com/if-ai/ComfyUI-IF_MemoAvatar" 38 | 39 | [tool.comfy] 40 | PublisherId = "impactframes" 41 | DisplayName = "IF_MemoAvatar" 42 | Icon = "https://impactframes.ai/System/Icons/48x48/if.png" 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git-lfs 2 | diffusers>=0.31.0 3 | audio-separator 4 | albumentations 5 | numba 6 | librosa 7 | modelscope 8 | transformers>=4.46.3 9 | numpy>=1.26.4 10 | PyYAML>=6.0.1 11 | moviepy>=1.0.3 12 | pillow>=10.4.0 13 | librosa==0.10.2 14 | audio-separator==0.24.1 15 | funasr==1.0.27 16 | modelscope 17 | insightface==0.7.3 18 | accelerate==1.1.1 19 | albumentations==1.4.21 20 | black==23.12.1 21 | einops==0.8.0 22 | ffmpeg-python==0.2.0 23 | huggingface-hub==0.26.2 24 | imageio==2.36.0 25 | imageio-ffmpeg==0.5.1 26 | hydra-core==1.3.2 27 | jax==0.4.35 28 | mediapipe==0.10.18 29 | modelscope==1.20.1 30 | omegaconf==2.3.0 31 | onnxruntime>=1.20.1 32 | onnxruntime-gpu>=1.20.1 33 | opencv-python-headless==4.10.0.84 34 | scikit-learn>=1.5.2 35 | scipy>=1.14.1 36 | tqdm>=4.67.1 37 | -------------------------------------------------------------------------------- /web/js/IF_MemoAvatar.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | 3 | // Create base styles for buttons 4 | const style = document.createElement('style'); 5 | style.textContent = ` 6 | .if-button { 7 | background: var(--comfy-input-bg); 8 | border: 1px solid var(--border-color); 9 | color: var(--input-text); 10 | padding: 4px 12px; 11 | border-radius: 4px; 12 | cursor: pointer; 13 | transition: all 0.2s ease; 14 | margin-right: 5px; 15 | } 16 | 17 | .if-button:hover { 18 | background: var(--comfy-input-bg-hover); 19 | } 20 | `; 21 | document.head.appendChild(style); 22 | 23 | app.registerExtension({ 24 | name: "Comfy.IF_MemoAvatar", 25 | async beforeRegisterNodeDef(nodeType, nodeData) { 26 | if (nodeData.name !== "IF_MemoAvatar") return; 27 | 28 | const origOnNodeCreated = nodeType.prototype.onNodeCreated; 29 | 30 | 
nodeType.prototype.onNodeCreated = function() { 31 | const result = origOnNodeCreated?.apply(this, arguments); 32 | 33 | // Add preview widget 34 | this.addWidget("preview", "preview", { 35 | serialize: false, 36 | size: [256, 256] 37 | }); 38 | 39 | // Ensure proper widget parent references 40 | if (this.widgets) { 41 | this.widgets.forEach(w => w.parent = this); 42 | } 43 | 44 | return result; 45 | }; 46 | 47 | // Add size handling 48 | nodeType.prototype.onResize = function(size) { 49 | const minWidth = 400; 50 | const minHeight = 200; 51 | size[0] = Math.max(size[0], minWidth); 52 | size[1] = Math.max(size[1], minHeight); 53 | }; 54 | 55 | // Add execution handling 56 | nodeType.prototype.onExecuted = function(message) { 57 | if (message.preview) { 58 | const previewWidget = this.widgets.find(w => w.name === "preview"); 59 | if (previewWidget) { 60 | previewWidget.value = message.preview; 61 | } 62 | } 63 | }; 64 | } 65 | }); -------------------------------------------------------------------------------- /workflow/IF_MemoAvatar.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 67, 3 | "last_link_id": 114, 4 | "nodes": [ 5 | { 6 | "id": 49, 7 | "type": "Note", 8 | "pos": [ 9 | 1176.2391357421875, 10 | 654.4822387695312 11 | ], 12 | "size": [ 13 | 210, 14 | 156.80001831054688 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [], 21 | "properties": {}, 22 | "widgets_values": [ 23 | "This does not populate but the video get saved on the output folder \nMaybe load it up with VHS node" 24 | ], 25 | "color": "#432", 26 | "bgcolor": "#653" 27 | }, 28 | { 29 | "id": 56, 30 | "type": "PreviewImage", 31 | "pos": [ 32 | -174.07064819335938, 33 | 707.8036499023438 34 | ], 35 | "size": [ 36 | 211.03074645996094, 37 | 246 38 | ], 39 | "flags": {}, 40 | "order": 8, 41 | "mode": 0, 42 | "inputs": [ 43 | { 44 | "name": "images", 45 | "type": "IMAGE", 46 | "link": 90 47 | } 48 | ], 49 | "outputs": [], 50 | "properties": { 51 | "Node name for S&R": "PreviewImage" 52 | }, 53 | "widgets_values": [] 54 | }, 55 | { 56 | "id": 54, 57 | "type": "ImageResizeKJ", 58 | "pos": [ 59 | -164.07064819335938, 60 | 157.803466796875 61 | ], 62 | "size": [ 63 | 210, 64 | 238 65 | ], 66 | "flags": { 67 | "collapsed": false 68 | }, 69 | "order": 6, 70 | "mode": 0, 71 | "inputs": [ 72 | { 73 | "name": "image", 74 | "type": "IMAGE", 75 | "link": 84 76 | }, 77 | { 78 | "name": "get_image_size", 79 | "type": "IMAGE", 80 | "link": null, 81 | "shape": 7 82 | }, 83 | { 84 | "name": "width_input", 85 | "type": "INT", 86 | "link": null, 87 | "widget": { 88 | "name": "width_input" 89 | }, 90 | "shape": 7 91 | }, 92 | { 93 | "name": "height_input", 94 | "type": "INT", 95 | "link": null, 96 | "widget": { 97 | "name": "height_input" 98 | }, 99 | "shape": 7 100 | } 101 | ], 102 | "outputs": [ 103 | { 104 | "name": "IMAGE", 105 | "type": "IMAGE", 106 | "links": [ 107 | 89 108 | ], 109 | "slot_index": 0 110 | }, 111 | { 112 | "name": "width", 113 | "type": "INT", 114 | "links": null 115 | }, 116 | { 117 | "name": "height", 118 | "type": "INT", 119 | "links": null 120 | } 121 | ], 122 | "properties": { 123 | "Node name for S&R": "ImageResizeKJ" 124 | }, 125 | "widgets_values": [ 126 | 1504, 127 | 1504, 128 | "lanczos", 129 | true, 130 | 2, 131 | 0, 132 | 0, 133 | "disabled" 134 | ] 135 | }, 136 | { 137 | "id": 57, 138 | "type": "ImageCrop+", 139 | "pos": [ 140 | -174.07064819335938, 141 | 457.80340576171875 142 | ], 143 | "size": [ 144 
| 210, 145 | 194 146 | ], 147 | "flags": {}, 148 | "order": 7, 149 | "mode": 0, 150 | "inputs": [ 151 | { 152 | "name": "image", 153 | "type": "IMAGE", 154 | "link": 89 155 | } 156 | ], 157 | "outputs": [ 158 | { 159 | "name": "IMAGE", 160 | "type": "IMAGE", 161 | "links": [ 162 | 90, 163 | 112 164 | ], 165 | "slot_index": 0 166 | }, 167 | { 168 | "name": "x", 169 | "type": "INT", 170 | "links": null 171 | }, 172 | { 173 | "name": "y", 174 | "type": "INT", 175 | "links": null 176 | } 177 | ], 178 | "properties": { 179 | "Node name for S&R": "ImageCrop+" 180 | }, 181 | "widgets_values": [ 182 | 1400, 183 | 1400, 184 | "top-center", 185 | -2, 186 | 51 187 | ] 188 | }, 189 | { 190 | "id": 59, 191 | "type": "Fast Groups Muter (rgthree)", 192 | "pos": [ 193 | 70.64138793945312, 194 | 937.8055419921875 195 | ], 196 | "size": [ 197 | 226.8000030517578, 198 | 106 199 | ], 200 | "flags": {}, 201 | "order": 1, 202 | "mode": 0, 203 | "inputs": [], 204 | "outputs": [ 205 | { 206 | "name": "OPT_CONNECTION", 207 | "type": "*", 208 | "links": null 209 | } 210 | ], 211 | "properties": { 212 | "matchColors": "", 213 | "matchTitle": "", 214 | "showNav": true, 215 | "sort": "position", 216 | "customSortAlphabet": "", 217 | "toggleRestriction": "default" 218 | } 219 | }, 220 | { 221 | "id": 32, 222 | "type": "IF_MemoCheckpointLoader", 223 | "pos": [ 224 | 72.47502136230469, 225 | 424.5469665527344 226 | ], 227 | "size": [ 228 | 315, 229 | 206 230 | ], 231 | "flags": {}, 232 | "order": 2, 233 | "mode": 0, 234 | "inputs": [], 235 | "outputs": [ 236 | { 237 | "name": "reference_net", 238 | "type": "MODEL", 239 | "links": [ 240 | 100 241 | ], 242 | "slot_index": 0 243 | }, 244 | { 245 | "name": "diffusion_net", 246 | "type": "MODEL", 247 | "links": [ 248 | 101 249 | ], 250 | "slot_index": 1 251 | }, 252 | { 253 | "name": "vae", 254 | "type": "VAE", 255 | "links": [ 256 | 102 257 | ], 258 | "slot_index": 2 259 | }, 260 | { 261 | "name": "image_proj", 262 | "type": "IMAGE_PROJ", 263 | "links": [ 264 | 103 265 | ], 266 | "slot_index": 3 267 | }, 268 | { 269 | "name": "audio_proj", 270 | "type": "AUDIO_PROJ", 271 | "links": [ 272 | 104 273 | ], 274 | "slot_index": 4 275 | }, 276 | { 277 | "name": "emotion_classifier", 278 | "type": "EMOTION_CLASSIFIER", 279 | "links": [ 280 | 105 281 | ], 282 | "slot_index": 5 283 | } 284 | ], 285 | "properties": { 286 | "Node name for S&R": "IF_MemoCheckpointLoader" 287 | }, 288 | "widgets_values": [ 289 | "memoavatar/memo" 290 | ] 291 | }, 292 | { 293 | "id": 60, 294 | "type": "LoadAudio", 295 | "pos": [ 296 | 68.33649444580078, 297 | 675.4642944335938 298 | ], 299 | "size": [ 300 | 315, 301 | 124 302 | ], 303 | "flags": {}, 304 | "order": 3, 305 | "mode": 0, 306 | "inputs": [], 307 | "outputs": [ 308 | { 309 | "name": "AUDIO", 310 | "type": "AUDIO", 311 | "links": [ 312 | 111 313 | ], 314 | "slot_index": 0 315 | } 316 | ], 317 | "properties": { 318 | "Node name for S&R": "LoadAudio" 319 | }, 320 | "widgets_values": [ 321 | "Hola.wav", 322 | null, 323 | "" 324 | ] 325 | }, 326 | { 327 | "id": 50, 328 | "type": "Note", 329 | "pos": [ 330 | 446.33880615234375, 331 | 203.38783264160156 332 | ], 333 | "size": [ 334 | 331.2768249511719, 335 | 117.3055648803711 336 | ], 337 | "flags": {}, 338 | "order": 4, 339 | "mode": 0, 340 | "inputs": [], 341 | "outputs": [], 342 | "properties": {}, 343 | "widgets_values": [ 344 | "3090 worked with 384 faster\n512 slower \nreducing fps braked the video but needs testing and optimizations \nThis is a wrapper is using diffusers and same code as the 
Original Repo I only made a few changes it should work on 3090 or maybe even in lesser cards if chosen a smaller res" 345 | ], 346 | "color": "#432", 347 | "bgcolor": "#653" 348 | }, 349 | { 350 | "id": 31, 351 | "type": "IF_DisplayText", 352 | "pos": [ 353 | 851.056640625, 354 | 660.5128784179688 355 | ], 356 | "size": [ 357 | 295.78253173828125, 358 | 228.17208862304688 359 | ], 360 | "flags": {}, 361 | "order": 11, 362 | "mode": 0, 363 | "inputs": [ 364 | { 365 | "name": "text", 366 | "type": "STRING", 367 | "link": 108, 368 | "widget": { 369 | "name": "text" 370 | } 371 | } 372 | ], 373 | "outputs": [ 374 | { 375 | "name": "text", 376 | "type": "STRING", 377 | "links": null, 378 | "tooltip": "Full text content" 379 | }, 380 | { 381 | "name": "text_list", 382 | "type": "STRING", 383 | "links": null, 384 | "shape": 6, 385 | "tooltip": "Individual lines as separate outputs" 386 | }, 387 | { 388 | "name": "count", 389 | "type": "INT", 390 | "links": null, 391 | "tooltip": "Total number of non-empty lines" 392 | }, 393 | { 394 | "name": "selected", 395 | "type": "STRING", 396 | "links": null, 397 | "tooltip": "Currently selected line based on select input" 398 | } 399 | ], 400 | "properties": { 401 | "Node name for S&R": "IF_DisplayText" 402 | }, 403 | "widgets_values": [ 404 | "", 405 | 0, 406 | "✅ Video saved as memo_video_20241216-193105.mp4" 407 | ] 408 | }, 409 | { 410 | "id": 61, 411 | "type": "IF_MemoAvatar", 412 | "pos": [ 413 | 424.6819152832031, 414 | 392.4902648925781 415 | ], 416 | "size": [ 417 | 400, 418 | 390 419 | ], 420 | "flags": {}, 421 | "order": 9, 422 | "mode": 0, 423 | "inputs": [ 424 | { 425 | "name": "image", 426 | "type": "IMAGE", 427 | "link": 112 428 | }, 429 | { 430 | "name": "audio", 431 | "type": "AUDIO", 432 | "link": 111 433 | }, 434 | { 435 | "name": "reference_net", 436 | "type": "MODEL", 437 | "link": 100 438 | }, 439 | { 440 | "name": "diffusion_net", 441 | "type": "MODEL", 442 | "link": 101 443 | }, 444 | { 445 | "name": "vae", 446 | "type": "VAE", 447 | "link": 102 448 | }, 449 | { 450 | "name": "image_proj", 451 | "type": "IMAGE_PROJ", 452 | "link": 103 453 | }, 454 | { 455 | "name": "audio_proj", 456 | "type": "AUDIO_PROJ", 457 | "link": 104 458 | }, 459 | { 460 | "name": "emotion_classifier", 461 | "type": "EMOTION_CLASSIFIER", 462 | "link": 105 463 | } 464 | ], 465 | "outputs": [ 466 | { 467 | "name": "video_path", 468 | "type": "STRING", 469 | "links": [ 470 | 114 471 | ], 472 | "slot_index": 0 473 | }, 474 | { 475 | "name": "status", 476 | "type": "STRING", 477 | "links": [ 478 | 108 479 | ], 480 | "slot_index": 1 481 | } 482 | ], 483 | "properties": { 484 | "Node name for S&R": "IF_MemoAvatar" 485 | }, 486 | "widgets_values": [ 487 | 512, 488 | 16, 489 | 30, 490 | 20, 491 | 3.5, 492 | 35, 493 | "randomize", 494 | "memo_video", 495 | { 496 | "serialize": false, 497 | "size": [ 498 | 256, 499 | 256 500 | ] 501 | } 502 | ] 503 | }, 504 | { 505 | "id": 67, 506 | "type": "IF_DisplayText", 507 | "pos": [ 508 | 855.5243530273438, 509 | 369.9153137207031 510 | ], 511 | "size": [ 512 | 295.78253173828125, 513 | 228.17208862304688 514 | ], 515 | "flags": {}, 516 | "order": 10, 517 | "mode": 0, 518 | "inputs": [ 519 | { 520 | "name": "text", 521 | "type": "STRING", 522 | "link": 114, 523 | "widget": { 524 | "name": "text" 525 | } 526 | } 527 | ], 528 | "outputs": [ 529 | { 530 | "name": "text", 531 | "type": "STRING", 532 | "links": null, 533 | "tooltip": "Full text content" 534 | }, 535 | { 536 | "name": "text_list", 537 | "type": "STRING", 538 | "links": 
null, 539 | "shape": 6, 540 | "tooltip": "Individual lines as separate outputs" 541 | }, 542 | { 543 | "name": "count", 544 | "type": "INT", 545 | "links": null, 546 | "tooltip": "Total number of non-empty lines" 547 | }, 548 | { 549 | "name": "selected", 550 | "type": "STRING", 551 | "links": null, 552 | "tooltip": "Currently selected line based on select input" 553 | } 554 | ], 555 | "properties": { 556 | "Node name for S&R": "IF_DisplayText" 557 | }, 558 | "widgets_values": [ 559 | "", 560 | 0, 561 | "D:\\ComfyUI\\output\\memo_video_20241216-193105.mp4" 562 | ] 563 | }, 564 | { 565 | "id": 53, 566 | "type": "IF_LoadImagesS", 567 | "pos": [ 568 | -524.0703735351562, 569 | 187.803466796875 570 | ], 571 | "size": [ 572 | 342.7641906738281, 573 | 730 574 | ], 575 | "flags": {}, 576 | "order": 5, 577 | "mode": 0, 578 | "inputs": [], 579 | "outputs": [ 580 | { 581 | "name": "images", 582 | "type": "IMAGE", 583 | "links": [ 584 | 84 585 | ], 586 | "slot_index": 0, 587 | "shape": 6 588 | }, 589 | { 590 | "name": "masks", 591 | "type": "MASK", 592 | "links": [], 593 | "shape": 6 594 | }, 595 | { 596 | "name": "image_paths", 597 | "type": "STRING", 598 | "links": null, 599 | "shape": 6 600 | }, 601 | { 602 | "name": "filenames", 603 | "type": "STRING", 604 | "links": null, 605 | "shape": 6 606 | }, 607 | { 608 | "name": "count_str", 609 | "type": "STRING", 610 | "links": null, 611 | "shape": 6 612 | }, 613 | { 614 | "name": "count_int", 615 | "type": "INT", 616 | "links": null, 617 | "shape": 6 618 | } 619 | ], 620 | "properties": { 621 | "Node name for S&R": "IF_LoadImagesS" 622 | }, 623 | "widgets_values": [ 624 | "thb_00a04b95511409f306f73f6da484ae41.jpg", 625 | "Refresh Previews 🔄", 626 | "C:\\Users\\SOYYO\\Pictures\\people", 627 | 0, 628 | 30, 629 | "100", 630 | true, 631 | 31, 632 | true, 633 | "alphabetical", 634 | "none", 635 | "red", 636 | "image", 637 | "Select Folder 📂", 638 | "Backup Input 💾", 639 | "Restore Input ♻️" 640 | ] 641 | } 642 | ], 643 | "links": [ 644 | [ 645 | 84, 646 | 53, 647 | 0, 648 | 54, 649 | 0, 650 | "IMAGE" 651 | ], 652 | [ 653 | 89, 654 | 54, 655 | 0, 656 | 57, 657 | 0, 658 | "IMAGE" 659 | ], 660 | [ 661 | 90, 662 | 57, 663 | 0, 664 | 56, 665 | 0, 666 | "IMAGE" 667 | ], 668 | [ 669 | 100, 670 | 32, 671 | 0, 672 | 61, 673 | 2, 674 | "MODEL" 675 | ], 676 | [ 677 | 101, 678 | 32, 679 | 1, 680 | 61, 681 | 3, 682 | "MODEL" 683 | ], 684 | [ 685 | 102, 686 | 32, 687 | 2, 688 | 61, 689 | 4, 690 | "VAE" 691 | ], 692 | [ 693 | 103, 694 | 32, 695 | 3, 696 | 61, 697 | 5, 698 | "IMAGE_PROJ" 699 | ], 700 | [ 701 | 104, 702 | 32, 703 | 4, 704 | 61, 705 | 6, 706 | "AUDIO_PROJ" 707 | ], 708 | [ 709 | 105, 710 | 32, 711 | 5, 712 | 61, 713 | 7, 714 | "EMOTION_CLASSIFIER" 715 | ], 716 | [ 717 | 108, 718 | 61, 719 | 1, 720 | 31, 721 | 0, 722 | "STRING" 723 | ], 724 | [ 725 | 111, 726 | 60, 727 | 0, 728 | 61, 729 | 1, 730 | "AUDIO" 731 | ], 732 | [ 733 | 112, 734 | 57, 735 | 0, 736 | 61, 737 | 0, 738 | "IMAGE" 739 | ], 740 | [ 741 | 114, 742 | 61, 743 | 0, 744 | 67, 745 | 0, 746 | "STRING" 747 | ] 748 | ], 749 | "groups": [ 750 | { 751 | "id": 1, 752 | "title": "AVATAR", 753 | "bounding": [ 754 | 65.28616333007812, 755 | 85.40703582763672, 756 | 1330.9530029296875, 757 | 813.2778930664062 758 | ], 759 | "color": "#444", 760 | "font_size": 24, 761 | "flags": {} 762 | }, 763 | { 764 | "id": 2, 765 | "title": "MEMO", 766 | "bounding": [ 767 | -534.0703735351562, 768 | 84.20343780517578, 769 | 589.9998168945312, 770 | 879.60009765625 771 | ], 772 | "color": "#444", 773 | "font_size": 24, 
774 | "flags": {} 775 | } 776 | ], 777 | "config": {}, 778 | "extra": { 779 | "ds": { 780 | "scale": 0.751314800901579, 781 | "offset": [ 782 | 654.7879901456771, 783 | 52.171584759852784 784 | ] 785 | }, 786 | "ue_links": [] 787 | }, 788 | "version": 0.4 789 | } -------------------------------------------------------------------------------- /workflow/IF_MemoAvatar_IF_Extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 79, 3 | "last_link_id": 206, 4 | "nodes": [ 5 | { 6 | "id": 56, 7 | "type": "PreviewImage", 8 | "pos": [ 9 | -174.07064819335938, 10 | 707.8036499023438 11 | ], 12 | "size": [ 13 | 211.03074645996094, 14 | 246 15 | ], 16 | "flags": {}, 17 | "order": 7, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 90 24 | } 25 | ], 26 | "outputs": [], 27 | "properties": { 28 | "Node name for S&R": "PreviewImage" 29 | }, 30 | "widgets_values": [] 31 | }, 32 | { 33 | "id": 50, 34 | "type": "Note", 35 | "pos": [ 36 | 446.33880615234375, 37 | 203.38783264160156 38 | ], 39 | "size": [ 40 | 331.2768249511719, 41 | 117.3055648803711 42 | ], 43 | "flags": {}, 44 | "order": 0, 45 | "mode": 0, 46 | "inputs": [], 47 | "outputs": [], 48 | "properties": {}, 49 | "widgets_values": [ 50 | "3090 worked with 384 faster\n512 slower \nreducing fps braked the video but needs testing and optimizations \nThis is a wrapper is using diffusers and same code as the Original Repo I only made a few changes it should work on 3090 or maybe even in lesser cards if chosen a smaller res" 51 | ], 52 | "color": "#432", 53 | "bgcolor": "#653" 54 | }, 55 | { 56 | "id": 54, 57 | "type": "ImageResizeKJ", 58 | "pos": [ 59 | -164.07064819335938, 60 | 157.803466796875 61 | ], 62 | "size": [ 63 | 210, 64 | 238 65 | ], 66 | "flags": { 67 | "collapsed": false 68 | }, 69 | "order": 5, 70 | "mode": 0, 71 | "inputs": [ 72 | { 73 | "name": "image", 74 | "type": "IMAGE", 75 | "link": 84 76 | }, 77 | { 78 | "name": "get_image_size", 79 | "type": "IMAGE", 80 | "link": null, 81 | "shape": 7 82 | }, 83 | { 84 | "name": "width_input", 85 | "type": "INT", 86 | "link": null, 87 | "widget": { 88 | "name": "width_input" 89 | }, 90 | "shape": 7 91 | }, 92 | { 93 | "name": "height_input", 94 | "type": "INT", 95 | "link": null, 96 | "widget": { 97 | "name": "height_input" 98 | }, 99 | "shape": 7 100 | } 101 | ], 102 | "outputs": [ 103 | { 104 | "name": "IMAGE", 105 | "type": "IMAGE", 106 | "links": [ 107 | 89 108 | ], 109 | "slot_index": 0 110 | }, 111 | { 112 | "name": "width", 113 | "type": "INT", 114 | "links": null 115 | }, 116 | { 117 | "name": "height", 118 | "type": "INT", 119 | "links": null 120 | } 121 | ], 122 | "properties": { 123 | "Node name for S&R": "ImageResizeKJ" 124 | }, 125 | "widgets_values": [ 126 | 1600, 127 | 1600, 128 | "lanczos", 129 | true, 130 | 2, 131 | 0, 132 | 0, 133 | "disabled" 134 | ] 135 | }, 136 | { 137 | "id": 59, 138 | "type": "Fast Groups Muter (rgthree)", 139 | "pos": [ 140 | 70.64138793945312, 141 | 937.8055419921875 142 | ], 143 | "size": [ 144 | 226.8000030517578, 145 | 106 146 | ], 147 | "flags": {}, 148 | "order": 1, 149 | "mode": 0, 150 | "inputs": [], 151 | "outputs": [ 152 | { 153 | "name": "OPT_CONNECTION", 154 | "type": "*", 155 | "links": null 156 | } 157 | ], 158 | "properties": { 159 | "matchColors": "", 160 | "matchTitle": "", 161 | "showNav": true, 162 | "sort": "position", 163 | "customSortAlphabet": "", 164 | "toggleRestriction": "default" 165 | } 166 | }, 167 | { 168 | 
"id": 67, 169 | "type": "IF_DisplayText", 170 | "pos": [ 171 | 858.1091918945312, 172 | 155.067626953125 173 | ], 174 | "size": [ 175 | 295.78253173828125, 176 | 228.17208862304688 177 | ], 178 | "flags": {}, 179 | "order": 9, 180 | "mode": 0, 181 | "inputs": [ 182 | { 183 | "name": "text", 184 | "type": "STRING", 185 | "link": 204, 186 | "widget": { 187 | "name": "text" 188 | } 189 | } 190 | ], 191 | "outputs": [ 192 | { 193 | "name": "text", 194 | "type": "STRING", 195 | "links": null, 196 | "tooltip": "Full text content" 197 | }, 198 | { 199 | "name": "text_list", 200 | "type": "STRING", 201 | "links": null, 202 | "shape": 6, 203 | "tooltip": "Individual lines as separate outputs" 204 | }, 205 | { 206 | "name": "count", 207 | "type": "INT", 208 | "links": null, 209 | "tooltip": "Total number of non-empty lines" 210 | }, 211 | { 212 | "name": "selected", 213 | "type": "STRING", 214 | "links": null, 215 | "tooltip": "Currently selected line based on select input" 216 | } 217 | ], 218 | "properties": { 219 | "Node name for S&R": "IF_DisplayText" 220 | }, 221 | "widgets_values": [ 222 | "", 223 | 0, 224 | "D:\\ComfyUI\\output\\memo_video_20241217-083345.mp4" 225 | ] 226 | }, 227 | { 228 | "id": 31, 229 | "type": "IF_DisplayText", 230 | "pos": [ 231 | 849.99169921875, 232 | 726.5304565429688 233 | ], 234 | "size": [ 235 | 295.78253173828125, 236 | 228.17208862304688 237 | ], 238 | "flags": {}, 239 | "order": 11, 240 | "mode": 0, 241 | "inputs": [ 242 | { 243 | "name": "text", 244 | "type": "STRING", 245 | "link": 205, 246 | "widget": { 247 | "name": "text" 248 | } 249 | } 250 | ], 251 | "outputs": [ 252 | { 253 | "name": "text", 254 | "type": "STRING", 255 | "links": null, 256 | "tooltip": "Full text content" 257 | }, 258 | { 259 | "name": "text_list", 260 | "type": "STRING", 261 | "links": null, 262 | "shape": 6, 263 | "tooltip": "Individual lines as separate outputs" 264 | }, 265 | { 266 | "name": "count", 267 | "type": "INT", 268 | "links": null, 269 | "tooltip": "Total number of non-empty lines" 270 | }, 271 | { 272 | "name": "selected", 273 | "type": "STRING", 274 | "links": null, 275 | "tooltip": "Currently selected line based on select input" 276 | } 277 | ], 278 | "properties": { 279 | "Node name for S&R": "IF_DisplayText" 280 | }, 281 | "widgets_values": [ 282 | "", 283 | 0, 284 | "✅ Video saved as memo_video_20241217-083345.mp4" 285 | ] 286 | }, 287 | { 288 | "id": 57, 289 | "type": "ImageCrop+", 290 | "pos": [ 291 | -174.07064819335938, 292 | 457.80340576171875 293 | ], 294 | "size": [ 295 | 210, 296 | 194 297 | ], 298 | "flags": {}, 299 | "order": 6, 300 | "mode": 0, 301 | "inputs": [ 302 | { 303 | "name": "image", 304 | "type": "IMAGE", 305 | "link": 89 306 | } 307 | ], 308 | "outputs": [ 309 | { 310 | "name": "IMAGE", 311 | "type": "IMAGE", 312 | "links": [ 313 | 90, 314 | 198 315 | ], 316 | "slot_index": 0 317 | }, 318 | { 319 | "name": "x", 320 | "type": "INT", 321 | "links": null 322 | }, 323 | { 324 | "name": "y", 325 | "type": "INT", 326 | "links": null 327 | } 328 | ], 329 | "properties": { 330 | "Node name for S&R": "ImageCrop+" 331 | }, 332 | "widgets_values": [ 333 | 1504, 334 | 1504, 335 | "center", 336 | -2, 337 | -41 338 | ] 339 | }, 340 | { 341 | "id": 60, 342 | "type": "LoadAudio", 343 | "pos": [ 344 | 70.73645782470703, 345 | 259.4642639160156 346 | ], 347 | "size": [ 348 | 315, 349 | 124 350 | ], 351 | "flags": {}, 352 | "order": 2, 353 | "mode": 0, 354 | "inputs": [], 355 | "outputs": [ 356 | { 357 | "name": "AUDIO", 358 | "type": "AUDIO", 359 | "links": [ 360 
| 199 361 | ], 362 | "slot_index": 0 363 | } 364 | ], 365 | "properties": { 366 | "Node name for S&R": "LoadAudio" 367 | }, 368 | "widgets_values": [ 369 | "candy.wav", 370 | null, 371 | "" 372 | ] 373 | }, 374 | { 375 | "id": 79, 376 | "type": "IF_MemoAvatar", 377 | "pos": [ 378 | 415.24700927734375, 379 | 423.51800537109375 380 | ], 381 | "size": [ 382 | 400, 383 | 390 384 | ], 385 | "flags": {}, 386 | "order": 8, 387 | "mode": 0, 388 | "inputs": [ 389 | { 390 | "name": "image", 391 | "type": "IMAGE", 392 | "link": 198 393 | }, 394 | { 395 | "name": "audio", 396 | "type": "AUDIO", 397 | "link": 199 398 | }, 399 | { 400 | "name": "reference_net", 401 | "type": "MODEL", 402 | "link": 196 403 | }, 404 | { 405 | "name": "diffusion_net", 406 | "type": "MODEL", 407 | "link": 197 408 | }, 409 | { 410 | "name": "vae", 411 | "type": "VAE", 412 | "link": 200 413 | }, 414 | { 415 | "name": "image_proj", 416 | "type": "IMAGE_PROJ", 417 | "link": 201 418 | }, 419 | { 420 | "name": "audio_proj", 421 | "type": "AUDIO_PROJ", 422 | "link": 202 423 | }, 424 | { 425 | "name": "emotion_classifier", 426 | "type": "EMOTION_CLASSIFIER", 427 | "link": 203 428 | } 429 | ], 430 | "outputs": [ 431 | { 432 | "name": "video_path", 433 | "type": "STRING", 434 | "links": [ 435 | 204, 436 | 206 437 | ], 438 | "slot_index": 0 439 | }, 440 | { 441 | "name": "status", 442 | "type": "STRING", 443 | "links": [ 444 | 205 445 | ], 446 | "slot_index": 1 447 | } 448 | ], 449 | "properties": { 450 | "Node name for S&R": "IF_MemoAvatar" 451 | }, 452 | "widgets_values": [ 453 | 512, 454 | 16, 455 | 30, 456 | 20, 457 | 3.5, 458 | 2047, 459 | "randomize", 460 | "memo_video", 461 | { 462 | "serialize": false, 463 | "size": [ 464 | 256, 465 | 256 466 | ] 467 | } 468 | ] 469 | }, 470 | { 471 | "id": 68, 472 | "type": "VHS_LoadVideoPath", 473 | "pos": [ 474 | 880.303466796875, 475 | 424.2923278808594 476 | ], 477 | "size": [ 478 | 291.62481689453125, 479 | 238 480 | ], 481 | "flags": {}, 482 | "order": 10, 483 | "mode": 0, 484 | "inputs": [ 485 | { 486 | "name": "meta_batch", 487 | "type": "VHS_BatchManager", 488 | "link": null, 489 | "shape": 7 490 | }, 491 | { 492 | "name": "vae", 493 | "type": "VAE", 494 | "link": null, 495 | "shape": 7 496 | }, 497 | { 498 | "name": "video", 499 | "type": "STRING", 500 | "link": 206, 501 | "widget": { 502 | "name": "video" 503 | } 504 | } 505 | ], 506 | "outputs": [ 507 | { 508 | "name": "IMAGE", 509 | "type": "IMAGE", 510 | "links": null 511 | }, 512 | { 513 | "name": "frame_count", 514 | "type": "INT", 515 | "links": null 516 | }, 517 | { 518 | "name": "audio", 519 | "type": "AUDIO", 520 | "links": null 521 | }, 522 | { 523 | "name": "video_info", 524 | "type": "VHS_VIDEOINFO", 525 | "links": null 526 | } 527 | ], 528 | "properties": { 529 | "Node name for S&R": "VHS_LoadVideoPath" 530 | }, 531 | "widgets_values": { 532 | "video": "", 533 | "force_rate": 0, 534 | "force_size": "Disabled", 535 | "custom_width": 512, 536 | "custom_height": 512, 537 | "frame_load_cap": 0, 538 | "skip_first_frames": 0, 539 | "select_every_nth": 1, 540 | "videopreview": { 541 | "hidden": false, 542 | "paused": false, 543 | "params": { 544 | "force_rate": 0, 545 | "frame_load_cap": 0, 546 | "skip_first_frames": 0, 547 | "select_every_nth": 1 548 | }, 549 | "muted": false 550 | } 551 | } 552 | }, 553 | { 554 | "id": 74, 555 | "type": "IF_MemoCheckpointLoader", 556 | "pos": [ 557 | 80.47496032714844, 558 | 463.7469482421875 559 | ], 560 | "size": [ 561 | 315, 562 | 158 563 | ], 564 | "flags": {}, 565 | "order": 3, 566 | 
"mode": 0, 567 | "inputs": [], 568 | "outputs": [ 569 | { 570 | "name": "reference_net", 571 | "type": "MODEL", 572 | "links": [ 573 | 196 574 | ], 575 | "slot_index": 0 576 | }, 577 | { 578 | "name": "diffusion_net", 579 | "type": "MODEL", 580 | "links": [ 581 | 197 582 | ], 583 | "slot_index": 1 584 | }, 585 | { 586 | "name": "vae", 587 | "type": "VAE", 588 | "links": [ 589 | 200 590 | ], 591 | "slot_index": 2 592 | }, 593 | { 594 | "name": "image_proj", 595 | "type": "IMAGE_PROJ", 596 | "links": [ 597 | 201 598 | ], 599 | "slot_index": 3 600 | }, 601 | { 602 | "name": "audio_proj", 603 | "type": "AUDIO_PROJ", 604 | "links": [ 605 | 202 606 | ], 607 | "slot_index": 4 608 | }, 609 | { 610 | "name": "emotion_classifier", 611 | "type": "EMOTION_CLASSIFIER", 612 | "links": [ 613 | 203 614 | ], 615 | "slot_index": 5 616 | } 617 | ], 618 | "properties": { 619 | "Node name for S&R": "IF_MemoCheckpointLoader" 620 | }, 621 | "widgets_values": [ 622 | true 623 | ] 624 | }, 625 | { 626 | "id": 53, 627 | "type": "IF_LoadImagesS", 628 | "pos": [ 629 | -524.0703735351562, 630 | 187.803466796875 631 | ], 632 | "size": [ 633 | 342.7641906738281, 634 | 730 635 | ], 636 | "flags": {}, 637 | "order": 4, 638 | "mode": 0, 639 | "inputs": [], 640 | "outputs": [ 641 | { 642 | "name": "images", 643 | "type": "IMAGE", 644 | "links": [ 645 | 84 646 | ], 647 | "slot_index": 0, 648 | "shape": 6 649 | }, 650 | { 651 | "name": "masks", 652 | "type": "MASK", 653 | "links": [], 654 | "shape": 6 655 | }, 656 | { 657 | "name": "image_paths", 658 | "type": "STRING", 659 | "links": null, 660 | "shape": 6 661 | }, 662 | { 663 | "name": "filenames", 664 | "type": "STRING", 665 | "links": null, 666 | "shape": 6 667 | }, 668 | { 669 | "name": "count_str", 670 | "type": "STRING", 671 | "links": null, 672 | "shape": 6 673 | }, 674 | { 675 | "name": "count_int", 676 | "type": "INT", 677 | "links": null, 678 | "shape": 6 679 | } 680 | ], 681 | "properties": { 682 | "Node name for S&R": "IF_LoadImagesS" 683 | }, 684 | "widgets_values": [ 685 | "thb_4bc166bd2e2e609360a884982b812965.jpg", 686 | "Refresh Previews 🔄", 687 | "C:\\Users\\SOYYO\\Pictures\\people", 688 | 0, 689 | 48, 690 | "100", 691 | true, 692 | 48, 693 | true, 694 | "alphabetical", 695 | "none", 696 | "red", 697 | "image", 698 | "Select Folder 📂", 699 | "Backup Input 💾", 700 | "Restore Input ♻️" 701 | ] 702 | } 703 | ], 704 | "links": [ 705 | [ 706 | 84, 707 | 53, 708 | 0, 709 | 54, 710 | 0, 711 | "IMAGE" 712 | ], 713 | [ 714 | 89, 715 | 54, 716 | 0, 717 | 57, 718 | 0, 719 | "IMAGE" 720 | ], 721 | [ 722 | 90, 723 | 57, 724 | 0, 725 | 56, 726 | 0, 727 | "IMAGE" 728 | ], 729 | [ 730 | 196, 731 | 74, 732 | 0, 733 | 79, 734 | 2, 735 | "MODEL" 736 | ], 737 | [ 738 | 197, 739 | 74, 740 | 1, 741 | 79, 742 | 3, 743 | "MODEL" 744 | ], 745 | [ 746 | 198, 747 | 57, 748 | 0, 749 | 79, 750 | 0, 751 | "IMAGE" 752 | ], 753 | [ 754 | 199, 755 | 60, 756 | 0, 757 | 79, 758 | 1, 759 | "AUDIO" 760 | ], 761 | [ 762 | 200, 763 | 74, 764 | 2, 765 | 79, 766 | 4, 767 | "VAE" 768 | ], 769 | [ 770 | 201, 771 | 74, 772 | 3, 773 | 79, 774 | 5, 775 | "IMAGE_PROJ" 776 | ], 777 | [ 778 | 202, 779 | 74, 780 | 4, 781 | 79, 782 | 6, 783 | "AUDIO_PROJ" 784 | ], 785 | [ 786 | 203, 787 | 74, 788 | 5, 789 | 79, 790 | 7, 791 | "EMOTION_CLASSIFIER" 792 | ], 793 | [ 794 | 204, 795 | 79, 796 | 0, 797 | 67, 798 | 0, 799 | "STRING" 800 | ], 801 | [ 802 | 205, 803 | 79, 804 | 1, 805 | 31, 806 | 0, 807 | "STRING" 808 | ], 809 | [ 810 | 206, 811 | 79, 812 | 0, 813 | 68, 814 | 2, 815 | "STRING" 816 | ] 817 | ], 818 | 
"groups": [ 819 | { 820 | "id": 1, 821 | "title": "AVATAR", 822 | "bounding": [ 823 | 65.28616333007812, 824 | 85.40703582763672, 825 | 1330.9530029296875, 826 | 813.2778930664062 827 | ], 828 | "color": "#444", 829 | "font_size": 24, 830 | "flags": {} 831 | }, 832 | { 833 | "id": 2, 834 | "title": "MEMO", 835 | "bounding": [ 836 | -534.0703735351562, 837 | 84.20343780517578, 838 | 589.9998168945312, 839 | 879.60009765625 840 | ], 841 | "color": "#444", 842 | "font_size": 24, 843 | "flags": {} 844 | } 845 | ], 846 | "config": {}, 847 | "extra": { 848 | "ds": { 849 | "scale": 1, 850 | "offset": [ 851 | 566.6587908244418, 852 | -17.309013215858414 853 | ] 854 | }, 855 | "ue_links": [] 856 | }, 857 | "version": 0.4 858 | } -------------------------------------------------------------------------------- /workflow/IF_MemoAvatar_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 82, 3 | "last_link_id": 209, 4 | "nodes": [ 5 | { 6 | "id": 56, 7 | "type": "PreviewImage", 8 | "pos": [ 9 | -174.07064819335938, 10 | 707.8036499023438 11 | ], 12 | "size": [ 13 | 211.03074645996094, 14 | 246 15 | ], 16 | "flags": {}, 17 | "order": 7, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 90 24 | } 25 | ], 26 | "outputs": [], 27 | "properties": { 28 | "Node name for S&R": "PreviewImage" 29 | }, 30 | "widgets_values": [] 31 | }, 32 | { 33 | "id": 50, 34 | "type": "Note", 35 | "pos": [ 36 | 446.33880615234375, 37 | 203.38783264160156 38 | ], 39 | "size": [ 40 | 331.2768249511719, 41 | 117.3055648803711 42 | ], 43 | "flags": {}, 44 | "order": 0, 45 | "mode": 0, 46 | "inputs": [], 47 | "outputs": [], 48 | "properties": {}, 49 | "widgets_values": [ 50 | "3090 worked with 384 faster\n512 slower \nreducing fps braked the video but needs testing and optimizations \nThis is a wrapper is using diffusers and same code as the Original Repo I only made a few changes it should work on 3090 or maybe even in lesser cards if chosen a smaller res" 51 | ], 52 | "color": "#432", 53 | "bgcolor": "#653" 54 | }, 55 | { 56 | "id": 59, 57 | "type": "Fast Groups Muter (rgthree)", 58 | "pos": [ 59 | 70.64138793945312, 60 | 937.8055419921875 61 | ], 62 | "size": [ 63 | 226.8000030517578, 64 | 106 65 | ], 66 | "flags": {}, 67 | "order": 1, 68 | "mode": 0, 69 | "inputs": [], 70 | "outputs": [ 71 | { 72 | "name": "OPT_CONNECTION", 73 | "type": "*", 74 | "links": null 75 | } 76 | ], 77 | "properties": { 78 | "matchColors": "", 79 | "matchTitle": "", 80 | "showNav": true, 81 | "sort": "position", 82 | "customSortAlphabet": "", 83 | "toggleRestriction": "default" 84 | } 85 | }, 86 | { 87 | "id": 57, 88 | "type": "ImageCrop+", 89 | "pos": [ 90 | -174.07064819335938, 91 | 457.80340576171875 92 | ], 93 | "size": [ 94 | 210, 95 | 194 96 | ], 97 | "flags": {}, 98 | "order": 6, 99 | "mode": 0, 100 | "inputs": [ 101 | { 102 | "name": "image", 103 | "type": "IMAGE", 104 | "link": 89 105 | } 106 | ], 107 | "outputs": [ 108 | { 109 | "name": "IMAGE", 110 | "type": "IMAGE", 111 | "links": [ 112 | 90, 113 | 198 114 | ], 115 | "slot_index": 0 116 | }, 117 | { 118 | "name": "x", 119 | "type": "INT", 120 | "links": null 121 | }, 122 | { 123 | "name": "y", 124 | "type": "INT", 125 | "links": null 126 | } 127 | ], 128 | "properties": { 129 | "Node name for S&R": "ImageCrop+" 130 | }, 131 | "widgets_values": [ 132 | 1504, 133 | 1504, 134 | "center", 135 | -2, 136 | -41 137 | ] 138 | }, 139 | { 140 | "id": 60, 141 | "type": "LoadAudio", 142 | 
"pos": [ 143 | 70.73645782470703, 144 | 259.4642639160156 145 | ], 146 | "size": [ 147 | 315, 148 | 124 149 | ], 150 | "flags": {}, 151 | "order": 2, 152 | "mode": 0, 153 | "inputs": [], 154 | "outputs": [ 155 | { 156 | "name": "AUDIO", 157 | "type": "AUDIO", 158 | "links": [ 159 | 199 160 | ], 161 | "slot_index": 0 162 | } 163 | ], 164 | "properties": { 165 | "Node name for S&R": "LoadAudio" 166 | }, 167 | "widgets_values": [ 168 | "candy.wav", 169 | null, 170 | "" 171 | ] 172 | }, 173 | { 174 | "id": 68, 175 | "type": "VHS_LoadVideoPath", 176 | "pos": [ 177 | 880.303466796875, 178 | 424.2923278808594 179 | ], 180 | "size": [ 181 | 291.62481689453125, 182 | 238 183 | ], 184 | "flags": {}, 185 | "order": 9, 186 | "mode": 0, 187 | "inputs": [ 188 | { 189 | "name": "meta_batch", 190 | "type": "VHS_BatchManager", 191 | "link": null, 192 | "shape": 7 193 | }, 194 | { 195 | "name": "vae", 196 | "type": "VAE", 197 | "link": null, 198 | "shape": 7 199 | }, 200 | { 201 | "name": "video", 202 | "type": "STRING", 203 | "link": 206, 204 | "widget": { 205 | "name": "video" 206 | } 207 | } 208 | ], 209 | "outputs": [ 210 | { 211 | "name": "IMAGE", 212 | "type": "IMAGE", 213 | "links": null 214 | }, 215 | { 216 | "name": "frame_count", 217 | "type": "INT", 218 | "links": null 219 | }, 220 | { 221 | "name": "audio", 222 | "type": "AUDIO", 223 | "links": null 224 | }, 225 | { 226 | "name": "video_info", 227 | "type": "VHS_VIDEOINFO", 228 | "links": null 229 | } 230 | ], 231 | "properties": { 232 | "Node name for S&R": "VHS_LoadVideoPath" 233 | }, 234 | "widgets_values": { 235 | "video": "", 236 | "force_rate": 0, 237 | "force_size": "Disabled", 238 | "custom_width": 512, 239 | "custom_height": 512, 240 | "frame_load_cap": 0, 241 | "skip_first_frames": 0, 242 | "select_every_nth": 1, 243 | "videopreview": { 244 | "hidden": false, 245 | "paused": false, 246 | "params": { 247 | "force_rate": 0, 248 | "frame_load_cap": 0, 249 | "skip_first_frames": 0, 250 | "select_every_nth": 1 251 | }, 252 | "muted": false 253 | } 254 | } 255 | }, 256 | { 257 | "id": 74, 258 | "type": "IF_MemoCheckpointLoader", 259 | "pos": [ 260 | 80.47496032714844, 261 | 463.7469482421875 262 | ], 263 | "size": [ 264 | 315, 265 | 158 266 | ], 267 | "flags": {}, 268 | "order": 3, 269 | "mode": 0, 270 | "inputs": [], 271 | "outputs": [ 272 | { 273 | "name": "reference_net", 274 | "type": "MODEL", 275 | "links": [ 276 | 196 277 | ], 278 | "slot_index": 0 279 | }, 280 | { 281 | "name": "diffusion_net", 282 | "type": "MODEL", 283 | "links": [ 284 | 197 285 | ], 286 | "slot_index": 1 287 | }, 288 | { 289 | "name": "vae", 290 | "type": "VAE", 291 | "links": [ 292 | 200 293 | ], 294 | "slot_index": 2 295 | }, 296 | { 297 | "name": "image_proj", 298 | "type": "IMAGE_PROJ", 299 | "links": [ 300 | 201 301 | ], 302 | "slot_index": 3 303 | }, 304 | { 305 | "name": "audio_proj", 306 | "type": "AUDIO_PROJ", 307 | "links": [ 308 | 202 309 | ], 310 | "slot_index": 4 311 | }, 312 | { 313 | "name": "emotion_classifier", 314 | "type": "EMOTION_CLASSIFIER", 315 | "links": [ 316 | 203 317 | ], 318 | "slot_index": 5 319 | } 320 | ], 321 | "properties": { 322 | "Node name for S&R": "IF_MemoCheckpointLoader" 323 | }, 324 | "widgets_values": [ 325 | true 326 | ] 327 | }, 328 | { 329 | "id": 54, 330 | "type": "ImageResizeKJ", 331 | "pos": [ 332 | -164.07064819335938, 333 | 157.803466796875 334 | ], 335 | "size": [ 336 | 210, 337 | 238 338 | ], 339 | "flags": { 340 | "collapsed": false 341 | }, 342 | "order": 5, 343 | "mode": 0, 344 | "inputs": [ 345 | { 346 | 
"name": "image", 347 | "type": "IMAGE", 348 | "link": 207 349 | }, 350 | { 351 | "name": "get_image_size", 352 | "type": "IMAGE", 353 | "link": null, 354 | "shape": 7 355 | }, 356 | { 357 | "name": "width_input", 358 | "type": "INT", 359 | "link": null, 360 | "widget": { 361 | "name": "width_input" 362 | }, 363 | "shape": 7 364 | }, 365 | { 366 | "name": "height_input", 367 | "type": "INT", 368 | "link": null, 369 | "widget": { 370 | "name": "height_input" 371 | }, 372 | "shape": 7 373 | } 374 | ], 375 | "outputs": [ 376 | { 377 | "name": "IMAGE", 378 | "type": "IMAGE", 379 | "links": [ 380 | 89 381 | ], 382 | "slot_index": 0 383 | }, 384 | { 385 | "name": "width", 386 | "type": "INT", 387 | "links": null 388 | }, 389 | { 390 | "name": "height", 391 | "type": "INT", 392 | "links": null 393 | } 394 | ], 395 | "properties": { 396 | "Node name for S&R": "ImageResizeKJ" 397 | }, 398 | "widgets_values": [ 399 | 1600, 400 | 1600, 401 | "lanczos", 402 | true, 403 | 2, 404 | 0, 405 | 0, 406 | "disabled" 407 | ] 408 | }, 409 | { 410 | "id": 80, 411 | "type": "LoadImage", 412 | "pos": [ 413 | -498.6588134765625, 414 | 202.10902404785156 415 | ], 416 | "size": [ 417 | 315, 418 | 314 419 | ], 420 | "flags": {}, 421 | "order": 4, 422 | "mode": 0, 423 | "inputs": [], 424 | "outputs": [ 425 | { 426 | "name": "IMAGE", 427 | "type": "IMAGE", 428 | "links": [ 429 | 207 430 | ] 431 | }, 432 | { 433 | "name": "MASK", 434 | "type": "MASK", 435 | "links": null 436 | } 437 | ], 438 | "properties": { 439 | "Node name for S&R": "LoadImage" 440 | }, 441 | "widgets_values": [ 442 | "candy@2x.png", 443 | "image" 444 | ] 445 | }, 446 | { 447 | "id": 81, 448 | "type": "ShowText|pysssss", 449 | "pos": [ 450 | 864.541259765625, 451 | 733.3090209960938 452 | ], 453 | "size": [ 454 | 315, 455 | 58 456 | ], 457 | "flags": {}, 458 | "order": 11, 459 | "mode": 0, 460 | "inputs": [ 461 | { 462 | "name": "text", 463 | "type": "STRING", 464 | "link": 208, 465 | "widget": { 466 | "name": "text" 467 | } 468 | } 469 | ], 470 | "outputs": [ 471 | { 472 | "name": "STRING", 473 | "type": "STRING", 474 | "links": null, 475 | "shape": 6 476 | } 477 | ], 478 | "properties": { 479 | "Node name for S&R": "ShowText|pysssss" 480 | }, 481 | "widgets_values": [ 482 | "" 483 | ] 484 | }, 485 | { 486 | "id": 82, 487 | "type": "ShowText|pysssss", 488 | "pos": [ 489 | 866.1412353515625, 490 | 292.5090026855469 491 | ], 492 | "size": [ 493 | 315, 494 | 58 495 | ], 496 | "flags": {}, 497 | "order": 10, 498 | "mode": 0, 499 | "inputs": [ 500 | { 501 | "name": "text", 502 | "type": "STRING", 503 | "link": 209, 504 | "widget": { 505 | "name": "text" 506 | } 507 | } 508 | ], 509 | "outputs": [ 510 | { 511 | "name": "STRING", 512 | "type": "STRING", 513 | "links": null, 514 | "shape": 6 515 | } 516 | ], 517 | "properties": { 518 | "Node name for S&R": "ShowText|pysssss" 519 | }, 520 | "widgets_values": [ 521 | "" 522 | ] 523 | }, 524 | { 525 | "id": 79, 526 | "type": "IF_MemoAvatar", 527 | "pos": [ 528 | 415.24700927734375, 529 | 423.51800537109375 530 | ], 531 | "size": [ 532 | 400, 533 | 390 534 | ], 535 | "flags": {}, 536 | "order": 8, 537 | "mode": 0, 538 | "inputs": [ 539 | { 540 | "name": "image", 541 | "type": "IMAGE", 542 | "link": 198 543 | }, 544 | { 545 | "name": "audio", 546 | "type": "AUDIO", 547 | "link": 199 548 | }, 549 | { 550 | "name": "reference_net", 551 | "type": "MODEL", 552 | "link": 196 553 | }, 554 | { 555 | "name": "diffusion_net", 556 | "type": "MODEL", 557 | "link": 197 558 | }, 559 | { 560 | "name": "vae", 561 | "type": 
"VAE", 562 | "link": 200 563 | }, 564 | { 565 | "name": "image_proj", 566 | "type": "IMAGE_PROJ", 567 | "link": 201 568 | }, 569 | { 570 | "name": "audio_proj", 571 | "type": "AUDIO_PROJ", 572 | "link": 202 573 | }, 574 | { 575 | "name": "emotion_classifier", 576 | "type": "EMOTION_CLASSIFIER", 577 | "link": 203 578 | } 579 | ], 580 | "outputs": [ 581 | { 582 | "name": "video_path", 583 | "type": "STRING", 584 | "links": [ 585 | 206, 586 | 209 587 | ], 588 | "slot_index": 0 589 | }, 590 | { 591 | "name": "status", 592 | "type": "STRING", 593 | "links": [ 594 | 208 595 | ], 596 | "slot_index": 1 597 | } 598 | ], 599 | "properties": { 600 | "Node name for S&R": "IF_MemoAvatar" 601 | }, 602 | "widgets_values": [ 603 | 384, 604 | 16, 605 | 30, 606 | 20, 607 | 3.5, 608 | 2047, 609 | "randomize", 610 | "memo_video", 611 | { 612 | "serialize": false, 613 | "size": [ 614 | 256, 615 | 256 616 | ] 617 | } 618 | ] 619 | } 620 | ], 621 | "links": [ 622 | [ 623 | 89, 624 | 54, 625 | 0, 626 | 57, 627 | 0, 628 | "IMAGE" 629 | ], 630 | [ 631 | 90, 632 | 57, 633 | 0, 634 | 56, 635 | 0, 636 | "IMAGE" 637 | ], 638 | [ 639 | 196, 640 | 74, 641 | 0, 642 | 79, 643 | 2, 644 | "MODEL" 645 | ], 646 | [ 647 | 197, 648 | 74, 649 | 1, 650 | 79, 651 | 3, 652 | "MODEL" 653 | ], 654 | [ 655 | 198, 656 | 57, 657 | 0, 658 | 79, 659 | 0, 660 | "IMAGE" 661 | ], 662 | [ 663 | 199, 664 | 60, 665 | 0, 666 | 79, 667 | 1, 668 | "AUDIO" 669 | ], 670 | [ 671 | 200, 672 | 74, 673 | 2, 674 | 79, 675 | 4, 676 | "VAE" 677 | ], 678 | [ 679 | 201, 680 | 74, 681 | 3, 682 | 79, 683 | 5, 684 | "IMAGE_PROJ" 685 | ], 686 | [ 687 | 202, 688 | 74, 689 | 4, 690 | 79, 691 | 6, 692 | "AUDIO_PROJ" 693 | ], 694 | [ 695 | 203, 696 | 74, 697 | 5, 698 | 79, 699 | 7, 700 | "EMOTION_CLASSIFIER" 701 | ], 702 | [ 703 | 206, 704 | 79, 705 | 0, 706 | 68, 707 | 2, 708 | "STRING" 709 | ], 710 | [ 711 | 207, 712 | 80, 713 | 0, 714 | 54, 715 | 0, 716 | "IMAGE" 717 | ], 718 | [ 719 | 208, 720 | 79, 721 | 1, 722 | 81, 723 | 0, 724 | "STRING" 725 | ], 726 | [ 727 | 209, 728 | 79, 729 | 0, 730 | 82, 731 | 0, 732 | "STRING" 733 | ] 734 | ], 735 | "groups": [ 736 | { 737 | "id": 1, 738 | "title": "AVATAR", 739 | "bounding": [ 740 | 65.28616333007812, 741 | 85.40703582763672, 742 | 1330.9530029296875, 743 | 813.2778930664062 744 | ], 745 | "color": "#444", 746 | "font_size": 24, 747 | "flags": {} 748 | }, 749 | { 750 | "id": 2, 751 | "title": "MEMO", 752 | "bounding": [ 753 | -534.0703735351562, 754 | 84.20343780517578, 755 | 589.9998168945312, 756 | 879.60009765625 757 | ], 758 | "color": "#444", 759 | "font_size": 24, 760 | "flags": {} 761 | } 762 | ], 763 | "config": {}, 764 | "extra": { 765 | "ds": { 766 | "scale": 1, 767 | "offset": [ 768 | 565.0588152385043, 769 | -17.309013215858414 770 | ] 771 | }, 772 | "ue_links": [] 773 | }, 774 | "version": 0.4 775 | } --------------------------------------------------------------------------------