├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── publish_to_registry.yml
├── IF_MemoAvatar.py
├── IF_MemoCheckpointLoader.py
├── LICENSE
├── README.md
├── __init__.py
├── configs
│   └── inference.yaml
├── examples
│   ├── candy.wav
│   ├── candy@2x.png
│   ├── dicaprio.jpg
│   └── speech.wav
├── memo
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── attention_processor.py
│   │   ├── audio_proj.py
│   │   ├── emotion_classifier.py
│   │   ├── image_proj.py
│   │   ├── motion_module.py
│   │   ├── normalization.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   ├── unet_3d_blocks.py
│   │   └── wav2vec.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   └── video_pipeline.py
│   └── utils
│       ├── __init__.py
│       ├── audio_utils.py
│       └── vision_utils.py
├── memo_model_manager.py
├── pyproject.toml
├── requirements.txt
├── web
│   └── js
│       └── IF_MemoAvatar.js
└── workflow
    ├── IF_MemoAvatar.json
    ├── IF_MemoAvatar_IF_Extensions.json
    └── IF_MemoAvatar_simple.json
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: ImpactFrames
2 | patreon: ImpactFrames
3 | ko_fi: impactframes
4 |
--------------------------------------------------------------------------------
/.github/workflows/publish_to_registry.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - "pyproject.toml"
9 |
10 | jobs:
11 | publish-node:
12 | name: Publish Custom Node to registry
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Check out code
16 | uses: actions/checkout@v4
17 | - name: Publish Custom Node
18 | uses: Comfy-Org/publish-node-action@main
19 | with:
20 | ## Add your own personal access token to your Github Repository secrets and reference it here.
21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
22 |
--------------------------------------------------------------------------------
/IF_MemoAvatar.py:
--------------------------------------------------------------------------------
1 | #IF_MemoAvatar.py
2 | import os
3 | import torch
4 | import numpy as np
5 | import torchaudio
6 | from PIL import Image
7 | import logging
8 | from tqdm import tqdm
9 | import time
10 | from contextlib import contextmanager
11 |
12 | import folder_paths
13 | import comfy.model_management
14 | from diffusers import FlowMatchEulerDiscreteScheduler
15 | from diffusers.utils import is_xformers_available
16 |
17 | from memo.pipelines.video_pipeline import VideoPipeline
18 | from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio
19 | from memo.utils.vision_utils import preprocess_image, tensor_to_video
20 | from memo_model_manager import MemoModelManager
21 |
22 | logger = logging.getLogger("memo")
23 |
24 | class IF_MemoAvatar:
25 | @classmethod
26 | def INPUT_TYPES(s):
27 | return {
28 | "required": {
29 | "image": ("IMAGE",),
30 | "audio": ("AUDIO",),
31 | "reference_net": ("MODEL",),
32 | "diffusion_net": ("MODEL",),
33 | "vae": ("VAE",),
34 | "image_proj": ("IMAGE_PROJ",),
35 | "audio_proj": ("AUDIO_PROJ",),
36 | "emotion_classifier": ("EMOTION_CLASSIFIER",),
37 | "resolution": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 8}),
38 | "num_frames_per_clip": ("INT", {"default": 16, "min": 1, "max": 32}),
39 | "fps": ("INT", {"default": 30, "min": 1, "max": 60}),
40 | "inference_steps": ("INT", {"default": 20, "min": 1, "max": 100}),
41 | "cfg_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 100.0}),
42 | "seed": ("INT", {"default": 42}),
43 | "output_name": ("STRING", {"default": "memo_video"})
44 | }
45 | }
46 |
47 | RETURN_TYPES = ("STRING", "STRING")
48 | RETURN_NAMES = ("video_path", "status")
49 | FUNCTION = "generate"
50 | CATEGORY = "ImpactFrames💥🎞️/MemoAvatar"
51 |
52 | def __init__(self):
53 | self.device = comfy.model_management.get_torch_device()
54 | # Use bfloat16 if available, fallback to float16
55 | self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
56 |
57 | # Initialize model manager and get paths
58 | self.model_manager = MemoModelManager()
59 | self.paths = self.model_manager.get_model_paths()
60 |
61 |
62 | def generate(self, image, audio, reference_net, diffusion_net, vae, image_proj, audio_proj,
63 | emotion_classifier, resolution=512, num_frames_per_clip=16, fps=30,
64 | inference_steps=20, cfg_scale=3.5, seed=42, output_name="memo_video"):
65 | try:
66 | # Save video
67 | timestamp = time.strftime('%Y%m%d-%H%M%S')
68 | video_name = f"{output_name}_{timestamp}.mp4"
69 | output_dir = folder_paths.get_output_directory()
70 | video_path = os.path.join(output_dir, video_name)
71 |
72 | # Memory optimizations
73 | torch.cuda.empty_cache()
74 | if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
75 | autocast = torch.cuda.amp.autocast
76 | else:
77 | @contextmanager
78 | def autocast():
79 | yield
80 |
81 | num_init_past_frames = 2
82 | num_past_frames = 16
83 |
84 | # Save input image temporarily
85 | temp_dir = folder_paths.get_temp_directory()
86 | temp_image = os.path.join(temp_dir, f"ref_image_{time.time()}.png")
87 |
88 | try:
89 | # Convert ComfyUI image format to PIL
90 | if isinstance(image, torch.Tensor):
91 | image = image.cpu().numpy()
92 | if image.ndim == 4:
93 | image = image[0]
94 | image = Image.fromarray((image * 255).astype(np.uint8))
95 | image.save(temp_image)
96 | face_models_path = self.paths["face_models"]
97 | print(f"face_models path: {face_models_path}")
98 | # Process image with our models
99 | pixel_values, face_emb = preprocess_image(
100 | self.paths["face_models"], # face_analysis_model
101 | temp_image, # image_path
102 | resolution # image_size
103 | )
104 | finally:
105 | if os.path.exists(temp_image):
106 | os.remove(temp_image)
107 |
108 | # Save audio temporarily
109 | temp_dir = folder_paths.get_temp_directory()
110 | temp_audio = os.path.join(temp_dir, f"temp_audio_{time.time()}.wav")
111 |
112 | try:
113 | # Convert 3D tensor to 2D if necessary
114 | waveform = audio["waveform"]
115 | if waveform.ndim == 3:
116 | waveform = waveform.squeeze(0) # Remove batch dimension
117 |
118 | # Save the audio tensor to a temporary WAV file
119 | torchaudio.save(temp_audio, waveform, audio["sample_rate"])
120 |
121 | # Set up audio cache directory
122 | cache_dir = os.path.join(folder_paths.get_temp_directory(), "memo_audio_cache")
123 | os.makedirs(cache_dir, exist_ok=True)
124 |
125 | resampled_path = os.path.join(cache_dir, f"resampled_{time.time()}-16k.wav")
126 | resampled_path = resample_audio(temp_audio, resampled_path)
127 |
128 | # Process audio
129 | audio_emb, audio_length = preprocess_audio(
130 | wav_path=resampled_path,
131 | num_generated_frames_per_clip=num_frames_per_clip,
132 | fps=fps,
133 | wav2vec_model=self.paths["wav2vec"],
134 | vocal_separator_model=self.paths["vocal_separator"],
135 | cache_dir=cache_dir,
136 | device=str(self.device)
137 | )
138 |
139 | # Extract emotion
140 | audio_emotion, num_emotion_classes = extract_audio_emotion_labels(
141 | model=self.paths["memo_base"],
142 | wav_path=resampled_path,
143 | emotion2vec_model=self.paths["emotion2vec"],
144 | audio_length=audio_length,
145 | device=str(self.device)
146 | )
147 |
148 | # Model optimizations
149 | vae.requires_grad_(False).eval()
150 | reference_net.requires_grad_(False).eval()
151 | diffusion_net.requires_grad_(False).eval()
152 | image_proj.requires_grad_(False).eval()
153 | audio_proj.requires_grad_(False).eval()
154 |
155 | # Enable memory efficient attention (Optional)
156 | if is_xformers_available():
157 | try:
158 | reference_net.enable_xformers_memory_efficient_attention()
159 | diffusion_net.enable_xformers_memory_efficient_attention()
160 | except Exception as e:
161 | logger.warning(
162 | f"Could not enable memory efficient attention for xformers: {e}."
163 | "Do you have xformers installed? "
164 | "If you do, please check your xformers installation"
165 | )
166 |
167 | # Create pipeline with optimizations
168 | noise_scheduler = FlowMatchEulerDiscreteScheduler()
169 | with torch.inference_mode():
170 | pipeline = VideoPipeline(
171 | vae=vae,
172 | reference_net=reference_net,
173 | diffusion_net=diffusion_net,
174 | scheduler=noise_scheduler,
175 | image_proj=image_proj,
176 | )
177 | pipeline.to(device=self.device, dtype=self.dtype)
178 |
179 | # Generate video frames with memory optimizations
180 | video_frames = []
181 | num_clips = audio_emb.shape[0] // num_frames_per_clip
182 | generator = torch.Generator(device=self.device).manual_seed(seed)
183 |
184 | for t in tqdm(range(num_clips), desc="Generating video clips"):
185 | # Clear cache at the start of each iteration
186 | if torch.cuda.is_available():
187 | torch.cuda.empty_cache()
188 |
189 | if len(video_frames) == 0:
190 | past_frames = pixel_values.repeat(num_init_past_frames, 1, 1, 1)
191 | past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
192 | pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
193 | else:
194 | past_frames = video_frames[-1][0]
195 | past_frames = past_frames.permute(1, 0, 2, 3)
196 | past_frames = past_frames[0 - num_past_frames:]
197 | past_frames = past_frames * 2.0 - 1.0
198 | past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
199 | pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0)
200 |
201 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
202 |
203 | # Process audio in smaller chunks if needed
204 | audio_tensor = (
205 | audio_emb[
206 | t * num_frames_per_clip : min(
207 | (t + 1) * num_frames_per_clip, audio_emb.shape[0]
208 | )
209 | ]
210 | .unsqueeze(0)
211 | .to(device=audio_proj.device, dtype=audio_proj.dtype)
212 | )
213 |
214 | with torch.inference_mode():
215 | audio_tensor = audio_proj(audio_tensor)
216 |
217 | audio_emotion_tensor = audio_emotion[
218 | t * num_frames_per_clip : min(
219 | (t + 1) * num_frames_per_clip, audio_emb.shape[0]
220 | )
221 | ]
222 |
223 | pipeline_output = pipeline(
224 | ref_image=pixel_values_ref_img,
225 | audio_tensor=audio_tensor,
226 | audio_emotion=audio_emotion_tensor,
227 | emotion_class_num=num_emotion_classes,
228 | face_emb=face_emb,
229 | width=resolution,
230 | height=resolution,
231 | video_length=num_frames_per_clip,
232 | num_inference_steps=inference_steps,
233 | guidance_scale=cfg_scale,
234 | generator=generator,
235 | )
236 |
237 | video_frames.append(pipeline_output.videos)
238 |
239 | video_frames = torch.cat(video_frames, dim=2)
240 | video_frames = video_frames.squeeze(0)
241 | video_frames = video_frames[:, :audio_length]
242 |
243 |
244 | tensor_to_video(video_frames, video_path, temp_audio, fps=fps)
245 | return (video_path, f"✅ Video saved as {video_name}")
246 |
247 | finally:
248 | # Clean up temporary files
249 | if os.path.exists(temp_audio):
250 | os.remove(temp_audio)
251 |
252 | except Exception as e:
253 | import traceback
254 | traceback.print_exc()
255 | return ("", f"❌ Error: {str(e)}")
256 |
257 | # Node mappings
258 | NODE_CLASS_MAPPINGS = {
259 | "IF_MemoAvatar": IF_MemoAvatar
260 | }
261 |
262 | NODE_DISPLAY_NAME_MAPPINGS = {
263 | "IF_MemoAvatar": "IF MemoAvatar 🗣️"
264 | }
--------------------------------------------------------------------------------
/IF_MemoCheckpointLoader.py:
--------------------------------------------------------------------------------
1 | #IF_MemoCheckpointLoader.py
2 | import os
3 | import torch
4 | import folder_paths
5 | import logging
6 | from diffusers import AutoencoderKL
7 | from diffusers.utils import is_xformers_available
8 | from packaging import version
9 | from safetensors.torch import load_file
10 | from huggingface_hub import hf_hub_download
11 |
12 | from memo.models.unet_2d_condition import UNet2DConditionModel
13 | from memo.models.unet_3d import UNet3DConditionModel
14 | from memo.models.image_proj import ImageProjModel
15 | from memo.models.audio_proj import AudioProjModel
16 | from memo.models.emotion_classifier import AudioEmotionClassifierModel
17 | from memo_model_manager import MemoModelManager
18 |
19 | logger = logging.getLogger("memo")
20 |
21 | class IF_MemoCheckpointLoader:
22 | @classmethod
23 | def INPUT_TYPES(s):
24 | return {
25 | "required": {
26 | "enable_xformers": ("BOOLEAN", {"default": True}),
27 | }
28 | }
29 |
30 | RETURN_TYPES = ("MODEL", "MODEL", "VAE", "IMAGE_PROJ", "AUDIO_PROJ", "EMOTION_CLASSIFIER")
31 | RETURN_NAMES = ("reference_net", "diffusion_net", "vae", "image_proj", "audio_proj", "emotion_classifier")
32 | FUNCTION = "load_checkpoint"
33 | CATEGORY = "ImpactFrames💥🎞️/MemoAvatar"
34 |
35 | def __init__(self):
36 | # Initialize model manager to set up all paths and auxiliary models
37 | self.model_manager = MemoModelManager()
38 | self.paths = self.model_manager.get_model_paths()
39 |
40 | def load_checkpoint(self, enable_xformers=True):
41 | try:
42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
43 | dtype = torch.float16 if str(device) == "cuda" else torch.float32
44 |
45 | logger.info("Loading models")
46 |
47 | # Fallback download function
48 | def fallback_download(repo_id, filename, local_dir):
49 | try:
50 | return hf_hub_download(
51 | repo_id=repo_id,
52 | filename=filename,
53 | local_dir=local_dir,
54 | local_dir_use_symlinks=False
55 | )
56 | except Exception as e:
57 | logger.warning(f"Failed to download {filename} from {repo_id}: {e}")
58 | return None
59 |
60 | # Load VAE with multiple fallback strategies
61 | try:
62 | vae = AutoencoderKL.from_pretrained(
63 | self.paths["vae"],
64 | use_safetensors=True,
65 | torch_dtype=dtype
66 | ).to(device=device)
67 | except Exception as e:
68 | logger.warning(f"Local VAE load failed: {e}. Attempting download.")
69 | fallback_download(
70 | "stabilityai/sd-vae-ft-mse",
71 | "diffusion_pytorch_model.safetensors",
72 | self.paths["vae"]
73 | )
74 | vae = AutoencoderKL.from_pretrained(
75 | "stabilityai/sd-vae-ft-mse",
76 | use_safetensors=True,
77 | torch_dtype=dtype
78 | ).to(device=device)
79 |
80 | # Load reference net
81 | reference_net = UNet2DConditionModel.from_pretrained(
82 | self.paths["memo_base"],
83 | subfolder="reference_net",
84 | use_safetensors=True
85 | )
86 | reference_net.requires_grad_(False)
87 | reference_net.eval()
88 |
89 | # Load diffusion net
90 | diffusion_net = UNet3DConditionModel.from_pretrained(
91 | self.paths["memo_base"],
92 | subfolder="diffusion_net",
93 | use_safetensors=True
94 | )
95 | diffusion_net.requires_grad_(False)
96 | diffusion_net.eval()
97 |
98 | # Load projectors
99 | image_proj = ImageProjModel.from_pretrained(
100 | self.paths["memo_base"],
101 | subfolder="image_proj",
102 | use_safetensors=True
103 | )
104 | image_proj.requires_grad_(False)
105 | image_proj.eval()
106 |
107 | audio_proj = AudioProjModel.from_pretrained(
108 | self.paths["memo_base"],
109 | subfolder="audio_proj",
110 | use_safetensors=True
111 | )
112 | audio_proj.requires_grad_(False)
113 | audio_proj.eval()
114 |
115 | # Enable xformers (Optional)
116 | if enable_xformers and is_xformers_available():
117 | try:
118 | import xformers
119 | xformers_version = version.parse(xformers.__version__)
120 | if xformers_version == version.parse("0.0.16"):
121 | logger.warning("xFormers 0.0.16 cannot be used for training in some GPUs.")
122 | reference_net.enable_xformers_memory_efficient_attention()
123 | diffusion_net.enable_xformers_memory_efficient_attention()
124 | except Exception as e:
125 | logger.warning(f"Could not enable xformers: {e}. Proceeding without xformers.")
126 | else:
127 | logger.info("Xformers is not enabled or not available. Proceeding without xformers.")
128 |
129 | # Move models to device
130 | for model in [reference_net, diffusion_net, image_proj, audio_proj]:
131 | model.to(device=device, dtype=dtype)
132 |
133 | # Load emotion classifier
134 | emotion_classifier = AudioEmotionClassifierModel()
135 | emotion_classifier_path = os.path.join(
136 | self.paths["memo_base"],
137 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors"
138 | )
139 | emotion_classifier.load_state_dict(load_file(emotion_classifier_path))
140 | emotion_classifier.to(device=device, dtype=dtype)
141 | emotion_classifier.eval()
142 |
143 | logger.info(f"Models loaded successfully to {device} with dtype {dtype}")
144 | return (reference_net, diffusion_net, vae, image_proj, audio_proj, emotion_classifier)
145 |
146 | except Exception as e:
147 | logger.error(f"Comprehensive model loading error: {e}")
148 | import traceback
149 | traceback.print_exc()
150 | raise RuntimeError(f"Failed to load models: {str(e)}")
151 |
152 | @classmethod
153 | def IS_CHANGED(s, **kwargs):
154 | return float("nan")
155 |
156 | NODE_CLASS_MAPPINGS = {
157 | "IF_MemoCheckpointLoader": IF_MemoCheckpointLoader
158 | }
159 |
160 | NODE_DISPLAY_NAME_MAPPINGS = {
161 | "IF_MemoCheckpointLoader": "IF Memo Checkpoint Loader 🎬"
162 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2024 ImpactFrames
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-IF_MemoAvatar
2 | Memory-Guided Diffusion for Expressive Talking Video Generation
3 |
4 |
5 | 
6 |
7 | # Original Repo
8 | **MEMO: Memory-Guided Diffusion for Expressive Talking Video Generation**
9 |
10 | [Longtao Zheng](https://ltzheng.github.io)\*,
11 | [Yifan Zhang](https://scholar.google.com/citations?user=zuYIUJEAAAAJ)\*,
12 | [Hanzhong Guo](https://scholar.google.com/citations?user=q3x6KsgAAAAJ),
13 | [Jiachun Pan](https://scholar.google.com/citations?user=nrOvfb4AAAAJ),
14 | [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ),
15 | [Jiahao Lu](https://scholar.google.com/citations?user=h7rbA-sAAAAJ),
16 | [Chuanxin Tang](https://scholar.google.com/citations?user=3ZC8B7MAAAAJ),
17 | [Bo An](https://personal.ntu.edu.sg/boan/index.html),
18 | [Shuicheng Yan](https://scholar.google.com/citations?user=DNuiPHwAAAAJ)
19 |
20 | _[Project Page](https://memoavatar.github.io) | [arXiv](https://arxiv.org/abs/2412.04448) | [Model](https://huggingface.co/memoavatar/memo)_
21 |
22 | This repository contains the example inference script for the MEMO-preview model. The gif demo below is compressed. See our [project page](https://memoavatar.github.io) for full videos.
23 |
24 |
25 | 
26 |
27 |
28 | # ComfyUI-IF_MemoAvatar
29 | Memory-Guided Diffusion for Expressive Talking Video Generation
30 |
31 | ## Overview
32 | This is a ComfyUI implementation of MEMO (Memory-Guided Diffusion for Expressive Talking Video Generation), which enables the creation of expressive talking avatar videos from a single image and audio input.
33 |
34 | ## Features
35 | - Generate expressive talking head videos from a single image
36 | - Audio-driven facial animation
37 | - Emotional expression transfer
38 | - High-quality video output
39 | 
40 |
41 |
42 | https://github.com/user-attachments/assets/bfbf896d-a609-4e0f-8ed3-16ec48f8d85a
43 |
44 |
45 | ## Installation
46 |
47 | ***Xformers is NOT required, but recommended if installed.***
48 | ***Make sure you have an HF token set in your environment variables.***
49 |
50 | Git clone the repo into your `custom_nodes` folder, then:
51 | ```bash
52 | cd ComfyUI-IF_MemoAvatar
53 | pip install -r requirements.txt
54 | ```
55 | I removed xformers from the requirements file because on Windows it needs a specific PyTorch combination to work.
56 |
57 | If you are on Linux you can simply run:
58 | ```bash
59 | pip install xformers
60 | ```
61 | Windows users: check whether xformers is already installed in your environment:
62 | ```bash
63 | pip show xformers
64 | ```
65 | If no version is shown, install the latest one. This free guide walks you through setting up a solid ComfyUI environment:
66 |
67 | [Installing Triton and Sage Attention Flash Attention](https://ko-fi.com/post/Installing-Triton-and-Sage-Attention-Flash-Attenti-P5P8175434)
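
If you prefer checking from Python, here is a minimal sketch using the same `is_xformers_available` helper the nodes rely on (purely illustrative; the nodes also work without xformers):

```python
# Optional dependency check: is xformers usable in this environment?
from diffusers.utils import is_xformers_available

if is_xformers_available():
    import xformers
    print(f"xformers {xformers.__version__} detected; memory-efficient attention can be enabled.")
else:
    print("xformers not found; the nodes will fall back to standard attention.")
```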
68 |
69 |
70 | [](https://www.youtube.com/watch?v=nSUGEdm2wU4)
71 |
72 |
73 | ### Model Files
74 | The models will automatically download to the following locations in your ComfyUI installation:
75 |
76 | ```bash
77 | models/checkpoints/memo/
78 | ├── audio_proj/
79 | ├── diffusion_net/
80 | ├── image_proj/
81 | ├── misc/
82 | │ ├── audio_emotion_classifier/
83 | │ ├── face_analysis/
84 | │ └── vocal_separator/
85 | └── reference_net/
86 | models/wav2vec/
87 | models/vae/sd-vae-ft-mse/
88 | models/emotion2vec/emotion2vec_plus_large/
89 |
90 | ```
91 |
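To confirm the downloads landed where the nodes expect them, here is a minimal check that mirrors the tree above (assumes a standard ComfyUI layout; adjust `models_root` to your install):

```python
# Sketch: verify the expected MEMO model folders exist (paths mirror the tree above).
from pathlib import Path

models_root = Path("ComfyUI/models")  # assumption: change to your ComfyUI models folder
expected = [
    "checkpoints/memo/audio_proj",
    "checkpoints/memo/diffusion_net",
    "checkpoints/memo/image_proj",
    "checkpoints/memo/misc/audio_emotion_classifier",
    "checkpoints/memo/misc/face_analysis",
    "checkpoints/memo/misc/vocal_separator",
    "checkpoints/memo/reference_net",
    "wav2vec",
    "vae/sd-vae-ft-mse",
    "emotion2vec/emotion2vec_plus_large",
]
for rel in expected:
    status = "OK     " if (models_root / rel).is_dir() else "MISSING"
    print(status, models_root / rel)
```
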
92 | Copy the models from `face_analysis/models` directly into `face_analysis`.
93 | For now, duplicate them instead of moving them, because
94 | Hugging Face will detect the empty folder and re-download them every time.
95 | If you don't see a `models.json`, or it errors out, create one yourself with this content:
96 | ```json
97 | {
98 | "detection": [
99 | "scrfd_10g_bnkps"
100 | ],
101 | "recognition": [
102 | "glintr100"
103 | ],
104 | "analysis": [
105 | "genderage",
106 | "2d106det",
107 | "1k3d68"
108 | ]
109 | }
110 | ```
111 | and a `version.txt` containing
112 | `0.7.3`
113 |
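If you'd rather script those two files, here is a minimal sketch (it assumes the `face_analysis` folder lives under `models/checkpoints/memo/misc/` as in the tree above; adjust the path if yours differs):

```python
# Sketch: write models.json and version.txt into the face_analysis folder.
import json
from pathlib import Path

face_analysis_dir = Path("ComfyUI/models/checkpoints/memo/misc/face_analysis")  # assumed location
face_analysis_dir.mkdir(parents=True, exist_ok=True)

models_json = {
    "detection": ["scrfd_10g_bnkps"],
    "recognition": ["glintr100"],
    "analysis": ["genderage", "2d106det", "1k3d68"],
}
(face_analysis_dir / "models.json").write_text(json.dumps(models_json, indent=2))
(face_analysis_dir / "version.txt").write_text("0.7.3")
print("Wrote models.json and version.txt to", face_analysis_dir)
```
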
114 | 
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #__init__.py
2 | import os
3 | import sys
4 | from pathlib import Path
5 |
6 | # Get the absolute path to the current directory and memo directory
7 | CURRENT_DIR = Path(__file__).parent.absolute()
8 | MEMO_DIR = CURRENT_DIR / "memo"
9 |
10 | # Add both directories to Python path if they're not already there
11 | if str(CURRENT_DIR) not in sys.path:
12 | sys.path.insert(0, str(CURRENT_DIR))
13 | if str(MEMO_DIR) not in sys.path:
14 | sys.path.insert(0, str(MEMO_DIR))
15 |
16 | # Create an empty __init__.py in memo directory if it doesn't exist
17 | memo_init = MEMO_DIR / "__init__.py"
18 | if not memo_init.exists():
19 | memo_init.touch()
20 |
21 | # Now import the components using absolute imports
22 | from .memo_model_manager import MemoModelManager
23 | from .IF_MemoAvatar import IF_MemoAvatar
24 | from .IF_MemoCheckpointLoader import IF_MemoCheckpointLoader
25 |
26 | NODE_CLASS_MAPPINGS = {
27 | "IF_MemoAvatar": IF_MemoAvatar,
28 | "IF_MemoCheckpointLoader": IF_MemoCheckpointLoader,
29 | }
30 |
31 | NODE_DISPLAY_NAME_MAPPINGS = {
32 | "IF_MemoAvatar": "IF MemoAvatar 🗣️",
33 | "IF_MemoCheckpointLoader": "IF Memo Checkpoint Loader"
34 | }
35 |
36 | # Define web directory relative to this file
37 | WEB_DIRECTORY = os.path.join(os.path.dirname(__file__), "web")
38 |
39 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
--------------------------------------------------------------------------------
/configs/inference.yaml:
--------------------------------------------------------------------------------
1 | resolution: 512
2 | num_generated_frames_per_clip: 16
3 | fps: 30
4 | num_init_past_frames: 2
5 | num_past_frames: 16
6 | inference_steps: 20
7 | cfg_scale: 3.5
8 | weight_dtype: bf16
9 | enable_xformers_memory_efficient_attention: true
10 |
11 | model_name_or_path: memoavatar/memo
12 | # model_name_or_path: checkpoints
13 | vae: stabilityai/sd-vae-ft-mse
14 | wav2vec: facebook/wav2vec2-base-960h
15 | emotion2vec: emotion2vec/emotion2vec_plus_large
16 | misc_model_dir: checkpoints
17 |
18 |
--------------------------------------------------------------------------------
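Note: the values in this config mirror the defaults exposed by the `IF_MemoAvatar` node inputs. A minimal sketch for loading it (assumes PyYAML; the upstream MEMO inference script may use a different config loader):

```python
# Sketch: load configs/inference.yaml and print the defaults the node mirrors.
import yaml

with open("configs/inference.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["resolution"], cfg["num_generated_frames_per_clip"], cfg["fps"])
print(cfg["inference_steps"], cfg["cfg_scale"], cfg["weight_dtype"])
```
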
/examples/candy.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/candy.wav
--------------------------------------------------------------------------------
/examples/candy@2x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/candy@2x.png
--------------------------------------------------------------------------------
/examples/dicaprio.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/dicaprio.jpg
--------------------------------------------------------------------------------
/examples/speech.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/examples/speech.wav
--------------------------------------------------------------------------------
/memo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/__init__.py
--------------------------------------------------------------------------------
/memo/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/models/__init__.py
--------------------------------------------------------------------------------
/memo/models/audio_proj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import ConfigMixin, ModelMixin
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class AudioProjModel(ModelMixin, ConfigMixin):
8 | def __init__(
9 | self,
10 | seq_len=5,
11 | blocks=12, # add a new parameter blocks
12 | channels=768, # add a new parameter channels
13 | intermediate_dim=512,
14 | output_dim=768,
15 | context_tokens=32,
16 | ):
17 | super().__init__()
18 |
19 | self.seq_len = seq_len
20 | self.blocks = blocks
21 | self.channels = channels
22 | self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
23 | self.intermediate_dim = intermediate_dim
24 | self.context_tokens = context_tokens
25 | self.output_dim = output_dim
26 |
27 | # define multiple linear layers
28 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
29 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
30 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
31 |
32 | self.norm = nn.LayerNorm(output_dim)
33 |
34 | def forward(self, audio_embeds):
35 | video_length = audio_embeds.shape[1]
36 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
37 | batch_size, window_size, blocks, channels = audio_embeds.shape
38 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
39 |
40 | audio_embeds = torch.relu(self.proj1(audio_embeds))
41 | audio_embeds = torch.relu(self.proj2(audio_embeds))
42 |
43 | context_tokens = self.proj3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
44 |
45 | context_tokens = self.norm(context_tokens)
46 | context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
47 |
48 | return context_tokens
49 |
--------------------------------------------------------------------------------
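Note: a minimal shape-check sketch for `AudioProjModel` (assumes the repo root is on `PYTHONPATH`; the 5-D input is `(batch, frames, seq_len, blocks, channels)` as consumed by `forward`):

```python
# Sketch: 5 wav2vec windows x 12 blocks x 768 channels per frame are projected
# into 32 context tokens of width 768 per frame.
import torch
from memo.models.audio_proj import AudioProjModel

proj = AudioProjModel()  # defaults: seq_len=5, blocks=12, channels=768, context_tokens=32
audio_embeds = torch.randn(1, 16, 5, 12, 768)  # (batch, frames, seq_len, blocks, channels)
with torch.no_grad():
    tokens = proj(audio_embeds)
print(tokens.shape)  # torch.Size([1, 16, 32, 768])
```
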
/memo/models/emotion_classifier.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from diffusers import ConfigMixin, ModelMixin
5 |
6 |
7 | class AudioEmotionClassifierModel(ModelMixin, ConfigMixin):
8 | num_emotion_classes = 9
9 |
10 | def __init__(self, num_classifier_layers=5, num_classifier_channels=2048):
11 | super().__init__()
12 |
13 | if num_classifier_layers == 1:
14 | self.layers = torch.nn.Linear(1024, self.num_emotion_classes)
15 | else:
16 | layer_list = [
17 | ("fc1", torch.nn.Linear(1024, num_classifier_channels)),
18 | ("relu1", torch.nn.ReLU()),
19 | ]
20 | for n in range(num_classifier_layers - 2):
21 | layer_list.append((f"fc{n+2}", torch.nn.Linear(num_classifier_channels, num_classifier_channels)))
22 | layer_list.append((f"relu{n+2}", torch.nn.ReLU()))
23 | layer_list.append(
24 | (f"fc{num_classifier_layers}", torch.nn.Linear(num_classifier_channels, self.num_emotion_classes))
25 | )
26 | self.layers = torch.nn.Sequential(OrderedDict(layer_list))
27 |
28 | def forward(self, x):
29 | x = self.layers(x)
30 | x = torch.softmax(x, dim=-1)
31 | return x
32 |
--------------------------------------------------------------------------------
/memo/models/image_proj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import ConfigMixin, ModelMixin
3 |
4 |
5 | class ImageProjModel(ModelMixin, ConfigMixin):
6 | def __init__(
7 | self,
8 | cross_attention_dim=768,
9 | clip_embeddings_dim=512,
10 | clip_extra_context_tokens=4,
11 | ):
12 | super().__init__()
13 |
14 | self.generator = None
15 | self.cross_attention_dim = cross_attention_dim
16 | self.clip_extra_context_tokens = clip_extra_context_tokens
17 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
18 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
19 |
20 | def forward(self, image_embeds):
21 | embeds = image_embeds
22 | clip_extra_context_tokens = self.proj(embeds).reshape(
23 | -1, self.clip_extra_context_tokens, self.cross_attention_dim
24 | )
25 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
26 | return clip_extra_context_tokens
27 |
--------------------------------------------------------------------------------
/memo/models/motion_module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import xformers
5 | import xformers.ops
6 | from diffusers.models.attention import FeedForward
7 | from diffusers.models.attention_processor import Attention
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from memo.models.attention import zero_module
13 | from memo.models.attention_processor import (
14 | MemoryLinearAttnProcessor,
15 | )
16 |
17 |
18 | class PositionalEncoding(nn.Module):
19 | def __init__(self, d_model, dropout=0.0, max_len=24):
20 | super().__init__()
21 | self.dropout = nn.Dropout(p=dropout)
22 | position = torch.arange(max_len).unsqueeze(1)
23 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
24 | pe = torch.zeros(1, max_len, d_model)
25 | pe[0, :, 0::2] = torch.sin(position * div_term)
26 | pe[0, :, 1::2] = torch.cos(position * div_term)
27 | self.register_buffer("pe", pe)
28 |
29 | def forward(self, x, offset=0):
30 | x = x + self.pe[:, offset : offset + x.size(1)]
31 | return self.dropout(x)
32 |
33 |
34 | class MemoryLinearAttnTemporalModule(nn.Module):
35 | def __init__(
36 | self,
37 | in_channels,
38 | num_attention_heads=8,
39 | num_transformer_block=2,
40 | attention_block_types=("Temporal_Self", "Temporal_Self"),
41 | temporal_position_encoding=False,
42 | temporal_position_encoding_max_len=24,
43 | temporal_attention_dim_div=1,
44 | zero_initialize=True,
45 | ):
46 | super().__init__()
47 |
48 | self.temporal_transformer = TemporalLinearAttnTransformer(
49 | in_channels=in_channels,
50 | num_attention_heads=num_attention_heads,
51 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
52 | num_layers=num_transformer_block,
53 | attention_block_types=attention_block_types,
54 | temporal_position_encoding=temporal_position_encoding,
55 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
56 | )
57 |
58 | if zero_initialize:
59 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
60 |
61 | def forward(
62 | self,
63 | hidden_states,
64 | motion_frames,
65 | encoder_hidden_states,
66 | is_new_audio=True,
67 | update_past_memory=False,
68 | ):
69 | hidden_states = self.temporal_transformer(
70 | hidden_states,
71 | motion_frames,
72 | encoder_hidden_states,
73 | is_new_audio=is_new_audio,
74 | update_past_memory=update_past_memory,
75 | )
76 |
77 | output = hidden_states
78 | return output
79 |
80 |
81 | class TemporalLinearAttnTransformer(nn.Module):
82 | def __init__(
83 | self,
84 | in_channels,
85 | num_attention_heads,
86 | attention_head_dim,
87 | num_layers,
88 | attention_block_types=(
89 | "Temporal_Self",
90 | "Temporal_Self",
91 | ),
92 | dropout=0.0,
93 | norm_num_groups=32,
94 | cross_attention_dim=768,
95 | activation_fn="geglu",
96 | attention_bias=False,
97 | upcast_attention=False,
98 | temporal_position_encoding=False,
99 | temporal_position_encoding_max_len=24,
100 | ):
101 | super().__init__()
102 |
103 | inner_dim = num_attention_heads * attention_head_dim
104 |
105 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
106 | self.proj_in = nn.Linear(in_channels, inner_dim)
107 |
108 | self.transformer_blocks = nn.ModuleList(
109 | [
110 | TemporalLinearAttnTransformerBlock(
111 | dim=inner_dim,
112 | num_attention_heads=num_attention_heads,
113 | attention_head_dim=attention_head_dim,
114 | attention_block_types=attention_block_types,
115 | dropout=dropout,
116 | cross_attention_dim=cross_attention_dim,
117 | activation_fn=activation_fn,
118 | attention_bias=attention_bias,
119 | upcast_attention=upcast_attention,
120 | temporal_position_encoding=temporal_position_encoding,
121 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
122 | )
123 | for _ in range(num_layers)
124 | ]
125 | )
126 | self.proj_out = nn.Linear(inner_dim, in_channels)
127 |
128 | def forward(
129 | self,
130 | hidden_states,
131 | motion_frames,
132 | encoder_hidden_states=None,
133 | is_new_audio=True,
134 | update_past_memory=False,
135 | ):
136 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
137 | video_length = hidden_states.shape[2]
138 | n_motion_frames = motion_frames.shape[2]
139 |
140 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
141 | with torch.no_grad():
142 | motion_frames = rearrange(motion_frames, "b c f h w -> (b f) c h w")
143 |
144 | batch, _, height, weight = hidden_states.shape
145 | residual = hidden_states
146 |
147 | hidden_states = self.norm(hidden_states)
148 | with torch.no_grad():
149 | motion_frames = self.norm(motion_frames)
150 |
151 | inner_dim = hidden_states.shape[1]
152 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
153 | hidden_states = self.proj_in(hidden_states)
154 |
155 | with torch.no_grad():
156 | (
157 | motion_frames_batch,
158 | motion_frames_inner_dim,
159 | motion_frames_height,
160 | motion_frames_weight,
161 | ) = motion_frames.shape
162 |
163 | motion_frames = motion_frames.permute(0, 2, 3, 1).reshape(
164 | motion_frames_batch,
165 | motion_frames_height * motion_frames_weight,
166 | motion_frames_inner_dim,
167 | )
168 | motion_frames = self.proj_in(motion_frames)
169 |
170 | # Transformer Blocks
171 | for block in self.transformer_blocks:
172 | hidden_states = block(
173 | hidden_states,
174 | motion_frames,
175 | encoder_hidden_states=encoder_hidden_states,
176 | video_length=video_length,
177 | n_motion_frames=n_motion_frames,
178 | is_new_audio=is_new_audio,
179 | update_past_memory=update_past_memory,
180 | )
181 |
182 | # output
183 | hidden_states = self.proj_out(hidden_states)
184 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
185 |
186 | output = hidden_states + residual
187 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
188 |
189 | return output
190 |
191 |
192 | class TemporalLinearAttnTransformerBlock(nn.Module):
193 | def __init__(
194 | self,
195 | dim,
196 | num_attention_heads,
197 | attention_head_dim,
198 | attention_block_types=(
199 | "Temporal_Self",
200 | "Temporal_Self",
201 | ),
202 | dropout=0.0,
203 | cross_attention_dim=768,
204 | activation_fn="geglu",
205 | attention_bias=False,
206 | upcast_attention=False,
207 | temporal_position_encoding=False,
208 | temporal_position_encoding_max_len=24,
209 | ):
210 | super().__init__()
211 |
212 | attention_blocks = []
213 | norms = []
214 |
215 | for block_name in attention_block_types:
216 | attention_blocks.append(
217 | MemoryLinearAttention(
218 | attention_mode=block_name.split("_", maxsplit=1)[0],
219 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
220 | query_dim=dim,
221 | heads=num_attention_heads,
222 | dim_head=attention_head_dim,
223 | dropout=dropout,
224 | bias=attention_bias,
225 | upcast_attention=upcast_attention,
226 | temporal_position_encoding=temporal_position_encoding,
227 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
228 | )
229 | )
230 | norms.append(nn.LayerNorm(dim))
231 |
232 | self.attention_blocks = nn.ModuleList(attention_blocks)
233 | self.norms = nn.ModuleList(norms)
234 |
235 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
236 | self.ff_norm = nn.LayerNorm(dim)
237 |
238 | def forward(
239 | self,
240 | hidden_states,
241 | motion_frames,
242 | encoder_hidden_states=None,
243 | video_length=None,
244 | n_motion_frames=None,
245 | is_new_audio=True,
246 | update_past_memory=False,
247 | ):
248 | for attention_block, norm in zip(self.attention_blocks, self.norms):
249 | norm_hidden_states = norm(hidden_states)
250 | with torch.no_grad():
251 | norm_motion_frames = norm(motion_frames)
252 | hidden_states = (
253 | attention_block(
254 | norm_hidden_states,
255 | norm_motion_frames,
256 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
257 | video_length=video_length,
258 | n_motion_frames=n_motion_frames,
259 | is_new_audio=is_new_audio,
260 | update_past_memory=update_past_memory,
261 | )
262 | + hidden_states
263 | )
264 |
265 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
266 |
267 | output = hidden_states
268 | return output
269 |
270 |
271 | class MemoryLinearAttention(Attention):
272 | def __init__(
273 | self,
274 | *args,
275 | attention_mode=None,
276 | temporal_position_encoding=False,
277 | temporal_position_encoding_max_len=24,
278 | **kwargs,
279 | ):
280 | super().__init__(*args, **kwargs)
281 | assert attention_mode == "Temporal"
282 |
283 | self.attention_mode = attention_mode
284 | self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
285 | self.query_dim = kwargs["query_dim"]
286 | self.temporal_position_encoding_max_len = temporal_position_encoding_max_len
287 | self.pos_encoder = (
288 | PositionalEncoding(
289 | kwargs["query_dim"],
290 | dropout=0.0,
291 | max_len=temporal_position_encoding_max_len,
292 | )
293 | if (temporal_position_encoding and attention_mode == "Temporal")
294 | else None
295 | )
296 |
297 | def extra_repr(self):
298 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
299 |
300 | def set_use_memory_efficient_attention_xformers(
301 | self,
302 | use_memory_efficient_attention_xformers: bool,
303 | attention_op=None,
304 | ):
305 | if use_memory_efficient_attention_xformers:
306 | if not is_xformers_available():
307 | raise ModuleNotFoundError(
308 | (
309 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
310 | " xformers"
311 | ),
312 | name="xformers",
313 | )
314 |
315 | if not torch.cuda.is_available():
316 | raise ValueError(
317 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
318 | " only available for GPU "
319 | )
320 |
321 | try:
322 | # Make sure we can run the memory efficient attention
323 | _ = xformers.ops.memory_efficient_attention(
324 | torch.randn((1, 2, 40), device="cuda"),
325 | torch.randn((1, 2, 40), device="cuda"),
326 | torch.randn((1, 2, 40), device="cuda"),
327 | )
328 | except Exception as e:
329 | raise e
330 | processor = MemoryLinearAttnProcessor()
331 | else:
332 | processor = MemoryLinearAttnProcessor()
333 |
334 | self.set_processor(processor)
335 |
336 | def forward(
337 | self,
338 | hidden_states,
339 | motion_frames,
340 | encoder_hidden_states=None,
341 | attention_mask=None,
342 | video_length=None,
343 | n_motion_frames=None,
344 | is_new_audio=True,
345 | update_past_memory=False,
346 | **cross_attention_kwargs,
347 | ):
348 | if self.attention_mode == "Temporal":
349 | d = hidden_states.shape[1]
350 | hidden_states = rearrange(
351 | hidden_states,
352 | "(b f) d c -> (b d) f c",
353 | f=video_length,
354 | )
355 |
356 | if self.pos_encoder is not None:
357 | hidden_states = self.pos_encoder(hidden_states)
358 |
359 | with torch.no_grad():
360 | motion_frames = rearrange(motion_frames, "(b f) d c -> (b d) f c", f=n_motion_frames)
361 |
362 | encoder_hidden_states = (
363 | repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
364 | if encoder_hidden_states is not None
365 | else encoder_hidden_states
366 | )
367 |
368 | else:
369 | raise NotImplementedError
370 |
371 | hidden_states = self.processor(
372 | self,
373 | hidden_states,
374 | motion_frames,
375 | encoder_hidden_states=encoder_hidden_states,
376 | attention_mask=attention_mask,
377 | n_motion_frames=n_motion_frames,
378 | is_new_audio=is_new_audio,
379 | update_past_memory=update_past_memory,
380 | **cross_attention_kwargs,
381 | )
382 |
383 | if self.attention_mode == "Temporal":
384 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
385 |
386 | return hidden_states
387 |
--------------------------------------------------------------------------------
/memo/models/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class FP32LayerNorm(nn.LayerNorm):
7 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
8 | origin_dtype = inputs.dtype
9 | return F.layer_norm(
10 | inputs.float(),
11 | self.normalized_shape,
12 | self.weight.float() if self.weight is not None else None,
13 | self.bias.float() if self.bias is not None else None,
14 | self.eps,
15 | ).to(origin_dtype)
16 |
--------------------------------------------------------------------------------
/memo/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class InflatedConv3d(nn.Conv2d):
8 | def forward(self, x):
9 | video_length = x.shape[2]
10 |
11 | x = rearrange(x, "b c f h w -> (b f) c h w")
12 | x = super().forward(x)
13 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
14 |
15 | return x
16 |
17 |
18 | class InflatedGroupNorm(nn.GroupNorm):
19 | def forward(self, x):
20 | video_length = x.shape[2]
21 |
22 | x = rearrange(x, "b c f h w -> (b f) c h w")
23 | x = super().forward(x)
24 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
25 |
26 | return x
27 |
28 |
29 | class Upsample3D(nn.Module):
30 | def __init__(
31 | self,
32 | channels,
33 | use_conv=False,
34 | use_conv_transpose=False,
35 | out_channels=None,
36 | name="conv",
37 | ):
38 | super().__init__()
39 | self.channels = channels
40 | self.out_channels = out_channels or channels
41 | self.use_conv = use_conv
42 | self.use_conv_transpose = use_conv_transpose
43 | self.name = name
44 |
45 | if use_conv_transpose:
46 | raise NotImplementedError
47 | if use_conv:
48 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
49 |
50 | def forward(self, hidden_states, output_size=None):
51 | assert hidden_states.shape[1] == self.channels
52 |
53 | if self.use_conv_transpose:
54 | raise NotImplementedError
55 |
56 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
57 | dtype = hidden_states.dtype
58 | if dtype == torch.bfloat16:
59 | hidden_states = hidden_states.to(torch.float32)
60 |
61 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
62 | if hidden_states.shape[0] >= 64:
63 | hidden_states = hidden_states.contiguous()
64 |
65 | # if `output_size` is passed we force the interpolation output
66 | # size and do not make use of `scale_factor=2`
67 | if output_size is None:
68 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
69 | else:
70 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
71 |
72 | # If the input is bfloat16, we cast back to bfloat16
73 | if dtype == torch.bfloat16:
74 | hidden_states = hidden_states.to(dtype)
75 |
76 | hidden_states = self.conv(hidden_states)
77 |
78 | return hidden_states
79 |
80 |
81 | class Downsample3D(nn.Module):
82 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
83 | super().__init__()
84 | self.channels = channels
85 | self.out_channels = out_channels or channels
86 | self.use_conv = use_conv
87 | self.padding = padding
88 | stride = 2
89 | self.name = name
90 |
91 | if use_conv:
92 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
93 | else:
94 | raise NotImplementedError
95 |
96 | def forward(self, hidden_states):
97 | assert hidden_states.shape[1] == self.channels
98 | if self.use_conv and self.padding == 0:
99 | raise NotImplementedError
100 |
101 | assert hidden_states.shape[1] == self.channels
102 | hidden_states = self.conv(hidden_states)
103 |
104 | return hidden_states
105 |
106 |
107 | class ResnetBlock3D(nn.Module):
108 | def __init__(
109 | self,
110 | *,
111 | in_channels,
112 | out_channels=None,
113 | conv_shortcut=False,
114 | dropout=0.0,
115 | temb_channels=512,
116 | groups=32,
117 | groups_out=None,
118 | pre_norm=True,
119 | eps=1e-6,
120 | non_linearity="swish",
121 | time_embedding_norm="default",
122 | output_scale_factor=1.0,
123 | use_in_shortcut=None,
124 | use_inflated_groupnorm=None,
125 | ):
126 | super().__init__()
127 | self.pre_norm = pre_norm
128 | self.pre_norm = True
129 | self.in_channels = in_channels
130 | out_channels = in_channels if out_channels is None else out_channels
131 | self.out_channels = out_channels
132 | self.use_conv_shortcut = conv_shortcut
133 | self.time_embedding_norm = time_embedding_norm
134 | self.output_scale_factor = output_scale_factor
135 |
136 | if groups_out is None:
137 | groups_out = groups
138 |
139 | assert use_inflated_groupnorm is not None
140 | if use_inflated_groupnorm:
141 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142 | else:
143 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144 |
145 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
146 |
147 | if temb_channels is not None:
148 | if self.time_embedding_norm == "default":
149 | time_emb_proj_out_channels = out_channels
150 | elif self.time_embedding_norm == "scale_shift":
151 | time_emb_proj_out_channels = out_channels * 2
152 | else:
153 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
154 |
155 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
156 | else:
157 | self.time_emb_proj = None
158 |
159 | if use_inflated_groupnorm:
160 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161 | else:
162 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
163 | self.dropout = torch.nn.Dropout(dropout)
164 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
165 |
166 | if non_linearity == "swish":
167 | self.nonlinearity = nn.SiLU()
168 | elif non_linearity == "mish":
169 | self.nonlinearity = Mish()
170 | elif non_linearity == "silu":
171 | self.nonlinearity = nn.SiLU()
172 |
173 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
174 |
175 | self.conv_shortcut = None
176 | if self.use_in_shortcut:
177 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
178 |
179 | def forward(self, input_tensor, temb):
180 | hidden_states = input_tensor
181 |
182 | hidden_states = self.norm1(hidden_states)
183 | hidden_states = self.nonlinearity(hidden_states)
184 |
185 | hidden_states = self.conv1(hidden_states)
186 |
187 | if temb is not None:
188 | if temb.dim() == 3:
189 | temb = self.time_emb_proj(self.nonlinearity(temb))
190 | temb = temb.transpose(1, 2).unsqueeze(-1).unsqueeze(-1)
191 | else:
192 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
193 |
194 | if temb is not None and self.time_embedding_norm == "default":
195 | hidden_states = hidden_states + temb
196 |
197 | hidden_states = self.norm2(hidden_states)
198 |
199 | if temb is not None and self.time_embedding_norm == "scale_shift":
200 | scale, shift = torch.chunk(temb, 2, dim=1)
201 | hidden_states = hidden_states * (1 + scale) + shift
202 |
203 | hidden_states = self.nonlinearity(hidden_states)
204 |
205 | hidden_states = self.dropout(hidden_states)
206 | hidden_states = self.conv2(hidden_states)
207 |
208 | if self.conv_shortcut is not None:
209 | input_tensor = self.conv_shortcut(input_tensor)
210 |
211 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
212 |
213 | return output_tensor
214 |
215 |
216 | class Mish(torch.nn.Module):
217 | def forward(self, hidden_states):
218 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
219 |
--------------------------------------------------------------------------------
/memo/models/transformer_2d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
7 | from diffusers.models.modeling_utils import ModelMixin
8 | from diffusers.models.normalization import AdaLayerNormSingle
9 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
10 | from torch import nn
11 |
12 | from memo.models.attention import BasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer2DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 | ref_feature_list: list[torch.FloatTensor]
19 |
20 |
21 | class Transformer2DModel(ModelMixin, ConfigMixin):
22 | _supports_gradient_checkpointing = True
23 |
24 | @register_to_config
25 | def __init__(
26 | self,
27 | num_attention_heads: int = 16,
28 | attention_head_dim: int = 88,
29 | in_channels: Optional[int] = None,
30 | out_channels: Optional[int] = None,
31 | num_layers: int = 1,
32 | dropout: float = 0.0,
33 | norm_num_groups: int = 32,
34 | cross_attention_dim: Optional[int] = None,
35 | attention_bias: bool = False,
36 | num_vector_embeds: Optional[int] = None,
37 | patch_size: Optional[int] = None,
38 | activation_fn: str = "geglu",
39 | num_embeds_ada_norm: Optional[int] = None,
40 | use_linear_projection: bool = False,
41 | only_cross_attention: bool = False,
42 | double_self_attention: bool = False,
43 | upcast_attention: bool = False,
44 | norm_type: str = "layer_norm",
45 | norm_elementwise_affine: bool = True,
46 | norm_eps: float = 1e-5,
47 | attention_type: str = "default",
48 | is_final_block: bool = False,
49 | ):
50 | super().__init__()
51 | self.use_linear_projection = use_linear_projection
52 | self.num_attention_heads = num_attention_heads
53 | self.attention_head_dim = attention_head_dim
54 | self.is_final_block = is_final_block
55 | inner_dim = num_attention_heads * attention_head_dim
56 |
57 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
58 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
59 |
60 | # 1. Transformer2DModel can process both standard continuous images of
61 | # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of
62 | # shape `(batch_size, num_image_vectors)`
63 | # Define whether input is continuous or discrete depending on configuration
64 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
65 | self.is_input_vectorized = num_vector_embeds is not None
66 | self.is_input_patches = in_channels is not None and patch_size is not None
67 |
68 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
69 | deprecation_message = (
70 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
71 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
72 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
73 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
74 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
75 | )
76 | deprecate(
77 | "norm_type!=num_embeds_ada_norm",
78 | "1.0.0",
79 | deprecation_message,
80 | standard_warn=False,
81 | )
82 | norm_type = "ada_norm"
83 |
84 | if self.is_input_continuous and self.is_input_vectorized:
85 | raise ValueError(
86 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
87 | " sure that either `in_channels` or `num_vector_embeds` is None."
88 | )
89 |
90 | if self.is_input_vectorized and self.is_input_patches:
91 | raise ValueError(
92 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
93 | " sure that either `num_vector_embeds` or `num_patches` is None."
94 | )
95 |
96 | if not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
97 | raise ValueError(
98 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
99 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
100 | )
101 |
102 | # 2. Define input layers
103 | self.in_channels = in_channels
104 |
105 | self.norm = torch.nn.GroupNorm(
106 | num_groups=norm_num_groups,
107 | num_channels=in_channels,
108 | eps=1e-6,
109 | affine=True,
110 | )
111 | if use_linear_projection:
112 | self.proj_in = linear_cls(in_channels, inner_dim)
113 | else:
114 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
115 |
116 | # 3. Define transformers blocks
117 | self.transformer_blocks = nn.ModuleList(
118 | [
119 | BasicTransformerBlock(
120 | inner_dim,
121 | num_attention_heads,
122 | attention_head_dim,
123 | dropout=dropout,
124 | cross_attention_dim=cross_attention_dim,
125 | activation_fn=activation_fn,
126 | num_embeds_ada_norm=num_embeds_ada_norm,
127 | attention_bias=attention_bias,
128 | only_cross_attention=only_cross_attention,
129 | double_self_attention=double_self_attention,
130 | upcast_attention=upcast_attention,
131 | norm_type=norm_type,
132 | norm_elementwise_affine=norm_elementwise_affine,
133 | norm_eps=norm_eps,
134 | attention_type=attention_type,
135 | is_final_block=(is_final_block and d == num_layers - 1),
136 | )
137 | for d in range(num_layers)
138 | ]
139 | )
140 |
141 | # 4. Define output layers
142 | self.out_channels = in_channels if out_channels is None else out_channels
143 | # TODO: should use out_channels for continuous projections
144 | if not is_final_block:
145 | if use_linear_projection:
146 | self.proj_out = linear_cls(inner_dim, in_channels)
147 | else:
148 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
149 |
150 | # 5. PixArt-Alpha blocks.
151 | self.adaln_single = None
152 | self.use_additional_conditions = False
153 | if norm_type == "ada_norm_single":
154 | self.use_additional_conditions = self.config.sample_size == 128
155 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
156 | # additional conditions until we find a better name
157 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
158 |
159 | self.caption_projection = None
160 |
161 | self.gradient_checkpointing = False
162 |
163 | def _set_gradient_checkpointing(self, module, value=False):
164 | if hasattr(module, "gradient_checkpointing"):
165 | module.gradient_checkpointing = value
166 |
167 | def forward(
168 | self,
169 | hidden_states: torch.Tensor,
170 | encoder_hidden_states: Optional[torch.Tensor] = None,
171 | timestep: Optional[torch.LongTensor] = None,
172 | class_labels: Optional[torch.LongTensor] = None,
173 | cross_attention_kwargs: Dict[str, Any] = None,
174 | attention_mask: Optional[torch.Tensor] = None,
175 | encoder_attention_mask: Optional[torch.Tensor] = None,
176 | return_dict: bool = True,
177 | ):
178 | if attention_mask is not None and attention_mask.ndim == 2:
179 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
180 | attention_mask = attention_mask.unsqueeze(1)
181 |
182 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
183 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
184 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
185 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
186 |
187 | # Retrieve lora scale.
188 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
189 |
190 | # 1. Input
191 | batch, _, height, width = hidden_states.shape
192 | residual = hidden_states
193 |
194 | hidden_states = self.norm(hidden_states)
195 | if not self.use_linear_projection:
196 | hidden_states = (
197 | self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states)
198 | )
199 | inner_dim = hidden_states.shape[1]
200 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
201 | else:
202 | inner_dim = hidden_states.shape[1]
203 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
204 | hidden_states = (
205 | self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states)
206 | )
207 |
208 | # 2. Blocks
209 | if self.caption_projection is not None:
210 | batch_size = hidden_states.shape[0]
211 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
212 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
213 |
214 | ref_feature_list = []
215 | for block in self.transformer_blocks:
216 | if self.training and self.gradient_checkpointing:
217 |
218 | def create_custom_forward(module, return_dict=None):
219 | def custom_forward(*inputs):
220 | if return_dict is not None:
221 | return module(*inputs, return_dict=return_dict)
222 |
223 | return module(*inputs)
224 |
225 | return custom_forward
226 |
227 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
228 | hidden_states, ref_feature = torch.utils.checkpoint.checkpoint(
229 | create_custom_forward(block),
230 | hidden_states,
231 | attention_mask,
232 | encoder_hidden_states,
233 | encoder_attention_mask,
234 | timestep,
235 | cross_attention_kwargs,
236 | class_labels,
237 | **ckpt_kwargs,
238 | )
239 | else:
240 | hidden_states, ref_feature = block(
241 | hidden_states, # shape [5, 4096, 320]
242 | attention_mask=attention_mask,
243 | encoder_hidden_states=encoder_hidden_states, # shape [1,4,768]
244 | encoder_attention_mask=encoder_attention_mask,
245 | timestep=timestep,
246 | cross_attention_kwargs=cross_attention_kwargs,
247 | class_labels=class_labels,
248 | )
249 | ref_feature_list.append(ref_feature)
250 |
251 | # 3. Output
252 | output = None
253 |
254 | if self.is_final_block:
255 | if not return_dict:
256 | return (output, ref_feature_list)
257 |
258 | return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list)
259 |
260 | if self.is_input_continuous:
261 | if not self.use_linear_projection:
262 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
263 | hidden_states = (
264 | self.proj_out(hidden_states, scale=lora_scale)
265 | if not USE_PEFT_BACKEND
266 | else self.proj_out(hidden_states)
267 | )
268 | else:
269 | hidden_states = (
270 | self.proj_out(hidden_states, scale=lora_scale)
271 | if not USE_PEFT_BACKEND
272 | else self.proj_out(hidden_states)
273 | )
274 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
275 |
276 | output = hidden_states + residual
277 | if not return_dict:
278 | return (output, ref_feature_list)
279 |
280 | return Transformer2DModelOutput(sample=output, ref_feature_list=ref_feature_list)
281 |
--------------------------------------------------------------------------------
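A minimal sketch, not part of the repository, of the token flatten/unflatten pattern used in Transformer2DModel.forward above: the feature map is reshaped to one token per spatial location for the transformer blocks and restored afterwards. The shapes (320 channels, 64x64) are illustrative only.

import torch

batch, inner_dim, height, width = 2, 320, 64, 64
hidden_states = torch.randn(batch, inner_dim, height, width)

# (b, c, h, w) -> (b, h*w, c): one token per spatial location
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

# ... the BasicTransformerBlock stack operates on `tokens` here ...

# (b, h*w, c) -> (b, c, h, w): restore the spatial layout before proj_out and the residual add
restored = tokens.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
assert restored.shape == hidden_states.shape

--------------------------------------------------------------------------------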
/memo/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from einops import rearrange, repeat
9 | from torch import nn
10 |
11 | from memo.models.attention import JointAudioTemporalBasicTransformerBlock, TemporalBasicTransformerBlock
12 |
13 |
14 | def create_custom_forward(module, return_dict=None):
15 | def custom_forward(*inputs):
16 | if return_dict is not None:
17 | return module(*inputs, return_dict=return_dict)
18 |
19 | return module(*inputs)
20 |
21 | return custom_forward
22 |
23 |
24 | @dataclass
25 | class Transformer3DModelOutput(BaseOutput):
26 | sample: torch.FloatTensor
27 |
28 |
29 | class Transformer3DModel(ModelMixin, ConfigMixin):
30 | _supports_gradient_checkpointing = True
31 |
32 | @register_to_config
33 | def __init__(
34 | self,
35 | num_attention_heads: int = 16,
36 | attention_head_dim: int = 88,
37 | in_channels: Optional[int] = None,
38 | num_layers: int = 1,
39 | dropout: float = 0.0,
40 | norm_num_groups: int = 32,
41 | cross_attention_dim: Optional[int] = None,
42 | attention_bias: bool = False,
43 | activation_fn: str = "geglu",
44 | use_linear_projection: bool = False,
45 | only_cross_attention: bool = False,
46 | upcast_attention: bool = False,
47 | unet_use_cross_frame_attention=None,
48 | unet_use_temporal_attention=None,
49 | use_audio_module=False,
50 | depth=0,
51 | unet_block_name=None,
52 | emo_drop_rate=0.3,
53 | is_final_block=False,
54 | ):
55 | super().__init__()
56 | self.use_linear_projection = use_linear_projection
57 | self.num_attention_heads = num_attention_heads
58 | self.attention_head_dim = attention_head_dim
59 | inner_dim = num_attention_heads * attention_head_dim
60 | self.use_audio_module = use_audio_module
61 | # Define input layers
62 | self.in_channels = in_channels
63 | self.is_final_block = is_final_block
64 |
65 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
66 | if use_linear_projection:
67 | self.proj_in = nn.Linear(in_channels, inner_dim)
68 | else:
69 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
70 |
71 | if use_audio_module:
72 | self.transformer_blocks = nn.ModuleList(
73 | [
74 | JointAudioTemporalBasicTransformerBlock(
75 | dim=inner_dim,
76 | num_attention_heads=num_attention_heads,
77 | attention_head_dim=attention_head_dim,
78 | dropout=dropout,
79 | cross_attention_dim=cross_attention_dim,
80 | activation_fn=activation_fn,
81 | attention_bias=attention_bias,
82 | only_cross_attention=only_cross_attention,
83 | upcast_attention=upcast_attention,
84 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85 | unet_use_temporal_attention=unet_use_temporal_attention,
86 | depth=depth,
87 | unet_block_name=unet_block_name,
88 | use_ada_layer_norm=True,
89 | emo_drop_rate=emo_drop_rate,
90 | is_final_block=(is_final_block and d == num_layers - 1),
91 | )
92 | for d in range(num_layers)
93 | ]
94 | )
95 | else:
96 | self.transformer_blocks = nn.ModuleList(
97 | [
98 | TemporalBasicTransformerBlock(
99 | inner_dim,
100 | num_attention_heads,
101 | attention_head_dim,
102 | dropout=dropout,
103 | cross_attention_dim=cross_attention_dim,
104 | activation_fn=activation_fn,
105 | attention_bias=attention_bias,
106 | only_cross_attention=only_cross_attention,
107 | upcast_attention=upcast_attention,
108 | )
109 | for _ in range(num_layers)
110 | ]
111 | )
112 |
113 | # 4. Define output layers
114 | if use_linear_projection:
115 | self.proj_out = nn.Linear(inner_dim, in_channels)
116 | else:
117 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
118 |
119 | self.gradient_checkpointing = False
120 |
121 | def _set_gradient_checkpointing(self, module, value=False):
122 | if hasattr(module, "gradient_checkpointing"):
123 | module.gradient_checkpointing = value
124 |
125 | def forward(
126 | self,
127 | hidden_states,
128 | ref_img_feature=None,
129 | encoder_hidden_states=None,
130 | attention_mask=None,
131 | timestep=None,
132 | emotion=None,
133 | uc_mask=None,
134 | return_dict: bool = True,
135 | ):
136 | # Input
137 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
138 | video_length = hidden_states.shape[2]
139 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
140 |
141 | if self.use_audio_module:
142 | if encoder_hidden_states.dim() == 4:
143 | encoder_hidden_states = rearrange(
144 | encoder_hidden_states,
145 | "bs f margin dim -> (bs f) margin dim",
146 | )
147 | else:
148 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
149 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
150 |
151 | batch, _, height, width = hidden_states.shape
152 | residual = hidden_states
153 | if self.use_audio_module:
154 | residual_audio = encoder_hidden_states
155 |
156 | hidden_states = self.norm(hidden_states)
157 | if not self.use_linear_projection:
158 | hidden_states = self.proj_in(hidden_states)
159 | inner_dim = hidden_states.shape[1]
160 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
161 | else:
162 | inner_dim = hidden_states.shape[1]
163 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
164 | hidden_states = self.proj_in(hidden_states)
165 |
166 | # Blocks
167 | for block in self.transformer_blocks:
168 | if self.training and self.gradient_checkpointing:
169 | if isinstance(block, TemporalBasicTransformerBlock):
170 | hidden_states = torch.utils.checkpoint.checkpoint(
171 | create_custom_forward(block),
172 | hidden_states,
173 | ref_img_feature,
174 | None, # attention_mask
175 | encoder_hidden_states,
176 | timestep,
177 | None, # cross_attention_kwargs
178 | video_length,
179 | uc_mask,
180 | )
181 | elif isinstance(block, JointAudioTemporalBasicTransformerBlock):
182 | (
183 | hidden_states,
184 | encoder_hidden_states,
185 | ) = torch.utils.checkpoint.checkpoint(
186 | create_custom_forward(block),
187 | hidden_states,
188 | encoder_hidden_states,
189 | attention_mask,
190 | emotion,
191 | )
192 | else:
193 | hidden_states = torch.utils.checkpoint.checkpoint(
194 | create_custom_forward(block),
195 | hidden_states,
196 | encoder_hidden_states,
197 | timestep,
198 | attention_mask,
199 | video_length,
200 | )
201 | else:
202 | if isinstance(block, TemporalBasicTransformerBlock):
203 | hidden_states = block(
204 | hidden_states=hidden_states,
205 | ref_img_feature=ref_img_feature,
206 | encoder_hidden_states=encoder_hidden_states,
207 | timestep=timestep,
208 | video_length=video_length,
209 | uc_mask=uc_mask,
210 | )
211 | elif isinstance(block, JointAudioTemporalBasicTransformerBlock):
212 | hidden_states, encoder_hidden_states = block(
213 | hidden_states, # shape [2, 4096, 320]
214 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
215 | attention_mask=attention_mask,
216 | emotion=emotion,
217 | )
218 | else:
219 | hidden_states = block(
220 | hidden_states, # shape [2, 4096, 320]
221 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
222 | attention_mask=attention_mask,
223 | timestep=timestep,
224 | video_length=video_length,
225 | )
226 |
227 | # Output
228 | if not self.use_linear_projection:
229 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
230 | hidden_states = self.proj_out(hidden_states)
231 | else:
232 | hidden_states = self.proj_out(hidden_states)
233 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
234 |
235 | output = hidden_states + residual
236 |
237 | if self.use_audio_module and not self.is_final_block:
238 | audio_output = encoder_hidden_states + residual_audio
239 | else:
240 | audio_output = None
241 |
242 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
243 | if not return_dict:
244 | if self.use_audio_module:
245 | return output, audio_output
246 | else:
247 | return output
248 |
249 | if self.use_audio_module:
250 | return output, audio_output
251 | else:
252 | return output
253 |
--------------------------------------------------------------------------------
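A minimal sketch, not part of the repository, of the frame folding that Transformer3DModel.forward above performs with einops: the frame axis is folded into the batch axis so each frame is processed like a 2D feature map, then unfolded again. The tensor sizes are illustrative only.

import torch
from einops import rearrange

b, c, f, h, w = 1, 320, 16, 32, 32
video = torch.randn(b, c, f, h, w)

frames = rearrange(video, "b c f h w -> (b f) c h w")   # (16, 320, 32, 32): per-frame 2D tensors
# ... the temporal / joint audio-temporal blocks operate on `frames` here ...
video_out = rearrange(frames, "(b f) c h w -> b c f h w", f=f)
assert video_out.shape == video.shape

--------------------------------------------------------------------------------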
/memo/models/unet_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, List, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.checkpoint
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.models.attention_processor import AttentionProcessor
9 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
10 | from diffusers.models.modeling_utils import ModelMixin
11 | from diffusers.utils import BaseOutput, logging
12 |
13 | from memo.models.resnet import InflatedConv3d, InflatedGroupNorm
14 | from memo.models.unet_3d_blocks import (
15 | UNetMidBlock3DCrossAttn,
16 | get_down_block,
17 | get_up_block,
18 | )
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 |
24 | @dataclass
25 | class UNet3DConditionOutput(BaseOutput):
26 | sample: torch.FloatTensor
27 |
28 |
29 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
30 | _supports_gradient_checkpointing = True
31 |
32 | @register_to_config
33 | def __init__(
34 | self,
35 | sample_size: Optional[int] = None,
36 | in_channels: int = 8,
37 | out_channels: int = 8,
38 | flip_sin_to_cos: bool = True,
39 | freq_shift: int = 0,
40 | down_block_types: Tuple[str] = (
41 | "CrossAttnDownBlock3D",
42 | "CrossAttnDownBlock3D",
43 | "CrossAttnDownBlock3D",
44 | "DownBlock3D",
45 | ),
46 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
47 | up_block_types: Tuple[str] = (
48 | "UpBlock3D",
49 | "CrossAttnUpBlock3D",
50 | "CrossAttnUpBlock3D",
51 | "CrossAttnUpBlock3D",
52 | ),
53 | only_cross_attention: Union[bool, Tuple[bool]] = False,
54 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
55 | layers_per_block: int = 2,
56 | downsample_padding: int = 1,
57 | mid_block_scale_factor: float = 1,
58 | act_fn: str = "silu",
59 | norm_num_groups: int = 32,
60 | norm_eps: float = 1e-5,
61 | cross_attention_dim: int = 1280,
62 | attention_head_dim: Union[int, Tuple[int]] = 8,
63 | dual_cross_attention: bool = False,
64 | use_linear_projection: bool = False,
65 | class_embed_type: Optional[str] = None,
66 | num_class_embeds: Optional[int] = None,
67 | upcast_attention: bool = False,
68 | resnet_time_scale_shift: str = "default",
69 | use_inflated_groupnorm=False,
70 | # Additional
71 | motion_module_resolutions=(1, 2, 4, 8),
72 | motion_module_kwargs=None,
73 | unet_use_cross_frame_attention=None,
74 | unet_use_temporal_attention=None,
75 | # audio
76 | audio_attention_dim=768,
77 | emo_drop_rate=0.3,
78 | ):
79 | super().__init__()
80 |
81 | self.sample_size = sample_size
82 | time_embed_dim = block_out_channels[0] * 4
83 |
84 | # input
85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86 |
87 | # time
88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89 | timestep_input_dim = block_out_channels[0]
90 |
91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92 |
93 | # class embedding
94 | if class_embed_type is None and num_class_embeds is not None:
95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96 | elif class_embed_type == "timestep":
97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98 | elif class_embed_type == "identity":
99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100 | else:
101 | self.class_embedding = None
102 |
103 | self.down_blocks = nn.ModuleList([])
104 | self.mid_block = None
105 | self.up_blocks = nn.ModuleList([])
106 |
107 | if isinstance(only_cross_attention, bool):
108 | only_cross_attention = [only_cross_attention] * len(down_block_types)
109 |
110 | if isinstance(attention_head_dim, int):
111 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
112 |
113 | # down
114 | output_channel = block_out_channels[0]
115 | for i, down_block_type in enumerate(down_block_types):
116 | res = 2**i
117 | input_channel = output_channel
118 | output_channel = block_out_channels[i]
119 | is_final_block = i == len(block_out_channels) - 1
120 |
121 | down_block = get_down_block(
122 | down_block_type,
123 | num_layers=layers_per_block,
124 | in_channels=input_channel,
125 | out_channels=output_channel,
126 | temb_channels=time_embed_dim,
127 | add_downsample=not is_final_block,
128 | resnet_eps=norm_eps,
129 | resnet_act_fn=act_fn,
130 | resnet_groups=norm_num_groups,
131 | cross_attention_dim=cross_attention_dim,
132 | attn_num_head_channels=attention_head_dim[i],
133 | downsample_padding=downsample_padding,
134 | dual_cross_attention=dual_cross_attention,
135 | use_linear_projection=use_linear_projection,
136 | only_cross_attention=only_cross_attention[i],
137 | upcast_attention=upcast_attention,
138 | resnet_time_scale_shift=resnet_time_scale_shift,
139 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
140 | unet_use_temporal_attention=unet_use_temporal_attention,
141 | use_inflated_groupnorm=use_inflated_groupnorm,
142 | use_motion_module=res in motion_module_resolutions,
143 | motion_module_kwargs=motion_module_kwargs,
144 | audio_attention_dim=audio_attention_dim,
145 | depth=i,
146 | emo_drop_rate=emo_drop_rate,
147 | )
148 | self.down_blocks.append(down_block)
149 |
150 | # mid
151 | if mid_block_type == "UNetMidBlock3DCrossAttn":
152 | self.mid_block = UNetMidBlock3DCrossAttn(
153 | in_channels=block_out_channels[-1],
154 | temb_channels=time_embed_dim,
155 | resnet_eps=norm_eps,
156 | resnet_act_fn=act_fn,
157 | output_scale_factor=mid_block_scale_factor,
158 | resnet_time_scale_shift=resnet_time_scale_shift,
159 | cross_attention_dim=cross_attention_dim,
160 | attn_num_head_channels=attention_head_dim[-1],
161 | resnet_groups=norm_num_groups,
162 | dual_cross_attention=dual_cross_attention,
163 | use_linear_projection=use_linear_projection,
164 | upcast_attention=upcast_attention,
165 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
166 | unet_use_temporal_attention=unet_use_temporal_attention,
167 | use_inflated_groupnorm=use_inflated_groupnorm,
168 | motion_module_kwargs=motion_module_kwargs,
169 | audio_attention_dim=audio_attention_dim,
170 | depth=3,
171 | emo_drop_rate=emo_drop_rate,
172 | )
173 | else:
174 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
175 |
176 | # count how many layers upsample the videos
177 | self.num_upsamplers = 0
178 |
179 | # up
180 | reversed_block_out_channels = list(reversed(block_out_channels))
181 | reversed_attention_head_dim = list(reversed(attention_head_dim))
182 | only_cross_attention = list(reversed(only_cross_attention))
183 | output_channel = reversed_block_out_channels[0]
184 | for i, up_block_type in enumerate(up_block_types):
185 | res = 2 ** (3 - i)
186 | is_final_block = i == len(block_out_channels) - 1
187 |
188 | prev_output_channel = output_channel
189 | output_channel = reversed_block_out_channels[i]
190 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
191 |
192 | # add upsample block for all BUT final layer
193 | if not is_final_block:
194 | add_upsample = True
195 | self.num_upsamplers += 1
196 | else:
197 | add_upsample = False
198 |
199 | up_block = get_up_block(
200 | up_block_type,
201 | num_layers=layers_per_block + 1,
202 | in_channels=input_channel,
203 | out_channels=output_channel,
204 | prev_output_channel=prev_output_channel,
205 | temb_channels=time_embed_dim,
206 | add_upsample=add_upsample,
207 | resnet_eps=norm_eps,
208 | resnet_act_fn=act_fn,
209 | resnet_groups=norm_num_groups,
210 | cross_attention_dim=cross_attention_dim,
211 | attn_num_head_channels=reversed_attention_head_dim[i],
212 | dual_cross_attention=dual_cross_attention,
213 | use_linear_projection=use_linear_projection,
214 | only_cross_attention=only_cross_attention[i],
215 | upcast_attention=upcast_attention,
216 | resnet_time_scale_shift=resnet_time_scale_shift,
217 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
218 | unet_use_temporal_attention=unet_use_temporal_attention,
219 | use_inflated_groupnorm=use_inflated_groupnorm,
220 | use_motion_module=res in motion_module_resolutions,
221 | motion_module_kwargs=motion_module_kwargs,
222 | audio_attention_dim=audio_attention_dim,
223 | depth=3 - i,
224 | emo_drop_rate=emo_drop_rate,
225 | is_final_block=is_final_block,
226 | )
227 | self.up_blocks.append(up_block)
228 | prev_output_channel = output_channel
229 |
230 | # out
231 | if use_inflated_groupnorm:
232 | self.conv_norm_out = InflatedGroupNorm(
233 | num_channels=block_out_channels[0],
234 | num_groups=norm_num_groups,
235 | eps=norm_eps,
236 | )
237 | else:
238 | self.conv_norm_out = nn.GroupNorm(
239 | num_channels=block_out_channels[0],
240 | num_groups=norm_num_groups,
241 | eps=norm_eps,
242 | )
243 | self.conv_act = nn.SiLU()
244 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
245 |
246 | @property
247 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
248 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
249 | r"""
250 | Returns:
251 | `dict` of attention processors: A dictionary containing all attention processors used in the model,
252 | indexed by their weight names.
253 | """
254 | # set recursively
255 | processors = {}
256 |
257 | def fn_recursive_add_processors(
258 | name: str,
259 | module: torch.nn.Module,
260 | processors: Dict[str, AttentionProcessor],
261 | ):
262 | if hasattr(module, "set_processor"):
263 | processors[f"{name}.processor"] = module.processor
264 |
265 | for sub_name, child in module.named_children():
266 | if "temporal_transformer" not in sub_name:
267 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
268 |
269 | return processors
270 |
271 | for name, module in self.named_children():
272 | if "temporal_transformer" not in name:
273 | fn_recursive_add_processors(name, module, processors)
274 |
275 | return processors
276 |
277 | def set_attention_slice(self, slice_size):
278 | r"""
279 | Enable sliced attention computation.
280 |
281 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
282 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
283 |
284 | Args:
285 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
286 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
287 | `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
288 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
289 | must be a multiple of `slice_size`.
290 | """
291 | sliceable_head_dims = []
292 |
293 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
294 | if hasattr(module, "set_attention_slice"):
295 | sliceable_head_dims.append(module.sliceable_head_dim)
296 |
297 | for child in module.children():
298 | fn_recursive_retrieve_slicable_dims(child)
299 |
300 | # retrieve number of attention layers
301 | for module in self.children():
302 | fn_recursive_retrieve_slicable_dims(module)
303 |
304 | num_slicable_layers = len(sliceable_head_dims)
305 |
306 | if slice_size == "auto":
307 | # half the attention head size is usually a good trade-off between
308 | # speed and memory
309 | slice_size = [dim // 2 for dim in sliceable_head_dims]
310 | elif slice_size == "max":
311 | # make smallest slice possible
312 | slice_size = num_slicable_layers * [1]
313 |
314 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
315 |
316 | if len(slice_size) != len(sliceable_head_dims):
317 | raise ValueError(
318 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
319 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
320 | )
321 |
322 | for i, size in enumerate(slice_size):
323 | dim = sliceable_head_dims[i]
324 | if size is not None and size > dim:
325 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
326 |
327 | # Recursively walk through all the children.
328 | # Any children which exposes the set_attention_slice method
329 | # gets the message
330 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
331 | if hasattr(module, "set_attention_slice"):
332 | module.set_attention_slice(slice_size.pop())
333 |
334 | for child in module.children():
335 | fn_recursive_set_attention_slice(child, slice_size)
336 |
337 | reversed_slice_size = list(reversed(slice_size))
338 | for module in self.children():
339 | fn_recursive_set_attention_slice(module, reversed_slice_size)
340 |
341 | def _set_gradient_checkpointing(self, module, value=False):
342 | if hasattr(module, "gradient_checkpointing"):
343 | module.gradient_checkpointing = value
344 |
345 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
346 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
347 | r"""
348 | Sets the attention processor to use to compute attention.
349 |
350 | Parameters:
351 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
352 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
353 | for **all** `Attention` layers.
354 |
355 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
356 | processor. This is strongly recommended when setting trainable attention processors.
357 |
358 | """
359 | count = len(self.attn_processors.keys())
360 |
361 | if isinstance(processor, dict) and len(processor) != count:
362 | raise ValueError(
363 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
364 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
365 | )
366 |
367 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
368 | if hasattr(module, "set_processor"):
369 | if not isinstance(processor, dict):
370 | module.set_processor(processor)
371 | else:
372 | module.set_processor(processor.pop(f"{name}.processor"))
373 |
374 | for sub_name, child in module.named_children():
375 | if "temporal_transformer" not in sub_name:
376 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
377 |
378 | for name, module in self.named_children():
379 | if "temporal_transformer" not in name:
380 | fn_recursive_attn_processor(name, module, processor)
381 |
382 | def forward(
383 | self,
384 | sample: torch.FloatTensor,
385 | ref_features: dict,
386 | timestep: Union[torch.Tensor, float, int, list],
387 | encoder_hidden_states: torch.Tensor,
388 | audio_embedding: Optional[torch.Tensor] = None,
389 | audio_emotion: Optional[torch.Tensor] = None,
390 | class_labels: Optional[torch.Tensor] = None,
391 | mask_cond_fea: Optional[torch.Tensor] = None,
392 | attention_mask: Optional[torch.Tensor] = None,
393 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
394 | mid_block_additional_residual: Optional[torch.Tensor] = None,
395 | uc_mask: Optional[torch.Tensor] = None,
396 | return_dict: bool = True,
397 | is_new_audio=True,
398 | update_past_memory=False,
399 | ) -> Union[UNet3DConditionOutput, Tuple]:
400 | # By default samples have to be at least a multiple of the overall upsampling factor.
401 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
402 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
403 | # on the fly if necessary.
404 | default_overall_up_factor = 2**self.num_upsamplers
405 |
406 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
407 | forward_upsample_size = False
408 | upsample_size = None
409 |
410 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
411 | logger.info("Forward upsample size to force interpolation output size.")
412 | forward_upsample_size = True
413 |
414 | # prepare attention_mask
415 | if attention_mask is not None:
416 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
417 | attention_mask = attention_mask.unsqueeze(1)
418 |
419 | # center input if necessary
420 | if self.config.center_input_sample:
421 | sample = 2 * sample - 1.0
422 |
423 | # time
424 | timesteps = timestep
425 | if isinstance(timesteps, list):
426 | t_emb_list = []
427 | for timesteps in timestep:
428 | if not torch.is_tensor(timesteps):
429 | # This would be a good case for the `match` statement (Python 3.10+)
430 | is_mps = sample.device.type == "mps"
431 | if isinstance(timestep, float):
432 | dtype = torch.float32 if is_mps else torch.float64
433 | else:
434 | dtype = torch.int32 if is_mps else torch.int64
435 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
436 | elif len(timesteps.shape) == 0:
437 | timesteps = timesteps[None].to(sample.device)
438 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
439 | timesteps = timesteps.expand(sample.shape[0])
440 | t_emb = self.time_proj(timesteps)
441 | t_emb_list.append(t_emb)
442 |
443 | t_emb = torch.stack(t_emb_list, dim=1)
444 | else:
445 | if not torch.is_tensor(timesteps):
446 | # This would be a good case for the `match` statement (Python 3.10+)
447 | is_mps = sample.device.type == "mps"
448 | if isinstance(timestep, float):
449 | dtype = torch.float32 if is_mps else torch.float64
450 | else:
451 | dtype = torch.int32 if is_mps else torch.int64
452 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
453 | elif len(timesteps.shape) == 0:
454 | timesteps = timesteps[None].to(sample.device)
455 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
456 | timesteps = timesteps.expand(sample.shape[0])
457 | t_emb = self.time_proj(timesteps)
458 |
459 | # timesteps does not contain any weights and will always return f32 tensors
460 | # but time_embedding might actually be running in fp16. so we need to cast here.
461 | # there might be better ways to encapsulate this.
462 | t_emb = t_emb.to(dtype=self.dtype)
463 | emb = self.time_embedding(t_emb)
464 |
465 | if self.class_embedding is not None:
466 | if class_labels is None:
467 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
468 |
469 | if self.config.class_embed_type == "timestep":
470 | class_labels = self.time_proj(class_labels)
471 |
472 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
473 | emb = emb + class_emb
474 |
475 | # pre-process
476 | sample = self.conv_in(sample)
477 | if mask_cond_fea is not None:
478 | sample = sample + mask_cond_fea
479 |
480 | # down
481 | down_block_res_samples = (sample,)
482 | for i, downsample_block in enumerate(self.down_blocks):
483 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
484 | sample, res_samples, audio_embedding = downsample_block(
485 | hidden_states=sample,
486 | ref_feature_list=ref_features["down"][i],
487 | temb=emb,
488 | encoder_hidden_states=encoder_hidden_states,
489 | attention_mask=attention_mask,
490 | audio_embedding=audio_embedding,
491 | emotion=audio_emotion,
492 | uc_mask=uc_mask,
493 | is_new_audio=is_new_audio,
494 | update_past_memory=update_past_memory,
495 | )
496 | else:
497 | sample, res_samples = downsample_block(
498 | hidden_states=sample,
499 | ref_feature_list=ref_features["down"][i],
500 | temb=emb,
501 | encoder_hidden_states=encoder_hidden_states,
502 | is_new_audio=is_new_audio,
503 | update_past_memory=update_past_memory,
504 | )
505 |
506 | down_block_res_samples += res_samples
507 |
508 | if down_block_additional_residuals is not None:
509 | new_down_block_res_samples = ()
510 |
511 | for down_block_res_sample, down_block_additional_residual in zip(
512 | down_block_res_samples, down_block_additional_residuals
513 | ):
514 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
515 | new_down_block_res_samples += (down_block_res_sample,)
516 |
517 | down_block_res_samples = new_down_block_res_samples
518 |
519 | # mid
520 | sample, audio_embedding = self.mid_block(
521 | sample,
522 | ref_feature_list=ref_features["mid"][0],
523 | temb=emb,
524 | encoder_hidden_states=encoder_hidden_states,
525 | attention_mask=attention_mask,
526 | audio_embedding=audio_embedding,
527 | emotion=audio_emotion,
528 | uc_mask=uc_mask,
529 | is_new_audio=is_new_audio,
530 | update_past_memory=update_past_memory,
531 | )
532 |
533 | if mid_block_additional_residual is not None:
534 | sample = sample + mid_block_additional_residual
535 |
536 | # up
537 | for i, upsample_block in enumerate(self.up_blocks):
538 | is_final_block = i == len(self.up_blocks) - 1
539 |
540 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
541 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
542 |
543 | # if we have not reached the final block and need to forward the
544 | # upsample size, we do it here
545 | if not is_final_block and forward_upsample_size:
546 | upsample_size = down_block_res_samples[-1].shape[2:]
547 |
548 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
549 | sample, audio_embedding = upsample_block(
550 | hidden_states=sample,
551 | ref_feature_list=ref_features["up"][i],
552 | temb=emb,
553 | res_hidden_states_tuple=res_samples,
554 | encoder_hidden_states=encoder_hidden_states,
555 | upsample_size=upsample_size,
556 | attention_mask=attention_mask,
557 | audio_embedding=audio_embedding,
558 | emotion=audio_emotion,
559 | uc_mask=uc_mask,
560 | is_new_audio=is_new_audio,
561 | update_past_memory=update_past_memory,
562 | )
563 | else:
564 | sample = upsample_block(
565 | hidden_states=sample,
566 | ref_feature_list=ref_features["up"][i],
567 | temb=emb,
568 | res_hidden_states_tuple=res_samples,
569 | upsample_size=upsample_size,
570 | encoder_hidden_states=encoder_hidden_states,
571 | is_new_audio=is_new_audio,
572 | update_past_memory=update_past_memory,
573 | )
574 |
575 | # post-process
576 | sample = self.conv_norm_out(sample)
577 | sample = self.conv_act(sample)
578 | sample = self.conv_out(sample)
579 |
580 | if not return_dict:
581 | return (sample,)
582 |
583 | return UNet3DConditionOutput(sample=sample)
584 |
--------------------------------------------------------------------------------
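A minimal sketch, not part of the repository, of the slice_size resolution performed in UNet3DConditionModel.set_attention_slice above: "auto" halves each sliceable head dimension, "max" uses a single slice per layer, and a bare int is broadcast to every layer. The helper name and example dims are made up for illustration.

def resolve_slice_sizes(sliceable_head_dims, slice_size):
    # mirror the branches in set_attention_slice
    if slice_size == "auto":
        return [dim // 2 for dim in sliceable_head_dims]
    if slice_size == "max":
        return [1] * len(sliceable_head_dims)
    if not isinstance(slice_size, list):
        return [slice_size] * len(sliceable_head_dims)
    return slice_size

print(resolve_slice_sizes([8, 8, 16], "auto"))  # [4, 4, 8]
print(resolve_slice_sizes([8, 8, 16], "max"))   # [1, 1, 1]
print(resolve_slice_sizes([8, 8, 16], 2))       # [2, 2, 2]

--------------------------------------------------------------------------------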
/memo/models/wav2vec.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from transformers import Wav2Vec2Model
3 | from transformers.modeling_outputs import BaseModelOutput
4 |
5 |
6 | class Wav2VecModel(Wav2Vec2Model):
7 | def forward(
8 | self,
9 | input_values,
10 | seq_len,
11 | attention_mask=None,
12 | mask_time_indices=None,
13 | output_attentions=None,
14 | output_hidden_states=None,
15 | return_dict=None,
16 | ):
17 | self.config.output_attentions = True
18 |
19 | output_hidden_states = (
20 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
21 | )
22 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
23 |
24 | extract_features = self.feature_extractor(input_values)
25 | extract_features = extract_features.transpose(1, 2)
26 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
27 |
28 | if attention_mask is not None:
29 | # compute reduced attention_mask corresponding to feature vectors
30 | attention_mask = self._get_feature_vector_attention_mask(
31 | extract_features.shape[1], attention_mask, add_adapter=False
32 | )
33 |
34 | hidden_states, extract_features = self.feature_projection(extract_features)
35 | hidden_states = self._mask_hidden_states(
36 | hidden_states,
37 | mask_time_indices=mask_time_indices,
38 | attention_mask=attention_mask,
39 | )
40 |
41 | encoder_outputs = self.encoder(
42 | hidden_states,
43 | attention_mask=attention_mask,
44 | output_attentions=output_attentions,
45 | output_hidden_states=output_hidden_states,
46 | return_dict=return_dict,
47 | )
48 |
49 | hidden_states = encoder_outputs[0]
50 |
51 | if self.adapter is not None:
52 | hidden_states = self.adapter(hidden_states)
53 |
54 | if not return_dict:
55 | return (hidden_states,) + encoder_outputs[1:]
56 | return BaseModelOutput(
57 | last_hidden_state=hidden_states,
58 | hidden_states=encoder_outputs.hidden_states,
59 | attentions=encoder_outputs.attentions,
60 | )
61 |
62 | def feature_extract(
63 | self,
64 | input_values,
65 | seq_len,
66 | ):
67 | extract_features = self.feature_extractor(input_values)
68 | extract_features = extract_features.transpose(1, 2)
69 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
70 |
71 | return extract_features
72 |
73 | def encode(
74 | self,
75 | extract_features,
76 | attention_mask=None,
77 | mask_time_indices=None,
78 | output_attentions=None,
79 | output_hidden_states=None,
80 | return_dict=None,
81 | ):
82 | self.config.output_attentions = True
83 |
84 | output_hidden_states = (
85 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
86 | )
87 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88 |
89 | if attention_mask is not None:
90 | # compute reduced attention_mask corresponding to feature vectors
91 | attention_mask = self._get_feature_vector_attention_mask(
92 | extract_features.shape[1], attention_mask, add_adapter=False
93 | )
94 |
95 | hidden_states, extract_features = self.feature_projection(extract_features)
96 | hidden_states = self._mask_hidden_states(
97 | hidden_states,
98 | mask_time_indices=mask_time_indices,
99 | attention_mask=attention_mask,
100 | )
101 |
102 | encoder_outputs = self.encoder(
103 | hidden_states,
104 | attention_mask=attention_mask,
105 | output_attentions=output_attentions,
106 | output_hidden_states=output_hidden_states,
107 | return_dict=return_dict,
108 | )
109 |
110 | hidden_states = encoder_outputs[0]
111 |
112 | if self.adapter is not None:
113 | hidden_states = self.adapter(hidden_states)
114 |
115 | if not return_dict:
116 | return (hidden_states,) + encoder_outputs[1:]
117 | return BaseModelOutput(
118 | last_hidden_state=hidden_states,
119 | hidden_states=encoder_outputs.hidden_states,
120 | attentions=encoder_outputs.attentions,
121 | )
122 |
123 |
124 | def linear_interpolation(features, seq_len):
125 | features = features.transpose(1, 2)
126 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode="linear")
127 | return output_features.transpose(1, 2)
128 |
--------------------------------------------------------------------------------
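A minimal sketch, not part of the repository, showing what linear_interpolation above does: it resamples wav2vec feature frames along the time axis to a target seq_len so the audio embeddings line up with the requested number of video frames. The function body is copied from wav2vec.py; the input sizes are illustrative only.

import torch
import torch.nn.functional as F

def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)  # (b, t, d) -> (b, d, t)
    features = F.interpolate(features, size=seq_len, align_corners=True, mode="linear")
    return features.transpose(1, 2)      # back to (b, seq_len, d)

feats = torch.randn(1, 49, 768)                        # roughly one second of wav2vec frames
print(linear_interpolation(feats, seq_len=30).shape)   # torch.Size([1, 30, 768]) for 30 fps video

--------------------------------------------------------------------------------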
/memo/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/pipelines/__init__.py
--------------------------------------------------------------------------------
/memo/pipelines/video_pipeline.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, List, Optional, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers import (
8 | DDIMScheduler,
9 | DiffusionPipeline,
10 | DPMSolverMultistepScheduler,
11 | EulerAncestralDiscreteScheduler,
12 | EulerDiscreteScheduler,
13 | LMSDiscreteScheduler,
14 | PNDMScheduler,
15 | )
16 | from diffusers.image_processor import VaeImageProcessor
17 | from diffusers.utils import BaseOutput
18 | from diffusers.utils.torch_utils import randn_tensor
19 | from einops import rearrange
20 |
21 |
22 | @dataclass
23 | class VideoPipelineOutput(BaseOutput):
24 | videos: Union[torch.Tensor, np.ndarray]
25 |
26 |
27 | class VideoPipeline(DiffusionPipeline):
28 | def __init__(
29 | self,
30 | vae,
31 | reference_net,
32 | diffusion_net,
33 | image_proj,
34 | scheduler: Union[
35 | DDIMScheduler,
36 | PNDMScheduler,
37 | LMSDiscreteScheduler,
38 | EulerDiscreteScheduler,
39 | EulerAncestralDiscreteScheduler,
40 | DPMSolverMultistepScheduler,
41 | ],
42 | ) -> None:
43 | super().__init__()
44 |
45 | self.register_modules(
46 | vae=vae,
47 | reference_net=reference_net,
48 | diffusion_net=diffusion_net,
49 | scheduler=scheduler,
50 | image_proj=image_proj,
51 | )
52 |
53 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
54 |
55 | self.ref_image_processor = VaeImageProcessor(
56 | vae_scale_factor=self.vae_scale_factor,
57 | do_convert_rgb=True,
58 | )
59 |
60 | @property
61 | def _execution_device(self):
62 | if self.device != torch.device("meta") or not hasattr(self.diffusion_net, "_hf_hook"):
63 | return self.device
64 | for module in self.diffusion_net.modules():
65 | if (
66 | hasattr(module, "_hf_hook")
67 | and hasattr(module._hf_hook, "execution_device")
68 | and module._hf_hook.execution_device is not None
69 | ):
70 | return torch.device(module._hf_hook.execution_device)
71 | return self.device
72 |
73 | def prepare_latents(
74 | self,
75 | batch_size: int, # Number of videos to generate in parallel
76 | num_channels_latents: int, # Number of channels in the latents
77 | width: int, # Width of the video frame
78 | height: int, # Height of the video frame
79 | video_length: int, # Length of the video in frames
80 | dtype: torch.dtype, # Data type of the latents
81 | device: torch.device, # Device to store the latents on
82 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
83 | latents: Optional[torch.Tensor] = None, # Pre-generated latents (optional)
84 | ):
85 | shape = (
86 | batch_size,
87 | num_channels_latents,
88 | video_length,
89 | height // self.vae_scale_factor,
90 | width // self.vae_scale_factor,
91 | )
92 | if isinstance(generator, list) and len(generator) != batch_size:
93 | raise ValueError(
94 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
95 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
96 | )
97 |
98 | if latents is None:
99 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
100 | else:
101 | latents = latents.to(device)
102 |
103 | # scale the initial noise by the standard deviation required by the scheduler
104 | if hasattr(self.scheduler, "init_noise_sigma"):
105 | latents = latents * self.scheduler.init_noise_sigma
106 | return latents
107 |
108 | def prepare_extra_step_kwargs(self, generator, eta):
109 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
110 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
111 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
112 | # and should be between [0, 1]
113 |
114 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
115 | extra_step_kwargs = {}
116 | if accepts_eta:
117 | extra_step_kwargs["eta"] = eta
118 |
119 | # check if the scheduler accepts generator
120 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
121 | if accepts_generator:
122 | extra_step_kwargs["generator"] = generator
123 | return extra_step_kwargs
124 |
125 | def decode_latents(self, latents):
126 | video_length = latents.shape[2]
127 | latents = 1 / 0.18215 * latents
128 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
129 | video = []
130 | for frame_idx in range(latents.shape[0]):
131 | video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
132 | video = torch.cat(video)
133 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
134 | video = (video / 2 + 0.5).clamp(0, 1)
135 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
136 | video = video.cpu().float().numpy()
137 | return video
138 |
139 | @torch.no_grad()
140 | def __call__(
141 | self,
142 | ref_image,
143 | face_emb,
144 | audio_tensor,
145 | width,
146 | height,
147 | video_length,
148 | num_inference_steps,
149 | guidance_scale,
150 | num_images_per_prompt=1,
151 | eta: float = 0.0,
152 | audio_emotion=None,
153 | emotion_class_num=None,
154 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
155 | output_type: Optional[str] = "tensor",
156 | return_dict: bool = True,
157 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
158 | callback_steps: Optional[int] = 1,
159 | ):
160 | # Default height and width to the diffusion net's sample size
161 | height = height or self.diffusion_net.config.sample_size * self.vae_scale_factor
162 | width = width or self.diffusion_net.config.sample_size * self.vae_scale_factor
163 |
164 | device = self._execution_device
165 |
166 | do_classifier_free_guidance = guidance_scale > 1.0
167 |
168 | # Prepare timesteps
169 | self.scheduler.set_timesteps(num_inference_steps, device=device)
170 | timesteps = self.scheduler.timesteps
171 |
172 | batch_size = 1
173 |
174 | # prepare clip image embeddings
175 | clip_image_embeds = face_emb
176 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
177 |
178 | encoder_hidden_states = self.image_proj(clip_image_embeds)
179 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
180 |
181 | if do_classifier_free_guidance:
182 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
183 |
184 | num_channels_latents = self.diffusion_net.in_channels
185 |
186 | latents = self.prepare_latents(
187 | batch_size * num_images_per_prompt,
188 | num_channels_latents,
189 | width,
190 | height,
191 | video_length,
192 | clip_image_embeds.dtype,
193 | device,
194 | generator,
195 | )
196 |
197 | # Prepare extra step kwargs.
198 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
199 |
200 | # Prepare ref image latents
201 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
202 | ref_image_tensor = self.ref_image_processor.preprocess(
203 | ref_image_tensor, height=height, width=width
204 | ) # (bs, c, width, height)
205 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
206 | # To save memory on GPUs like RTX 4090, we encode each frame separately
207 | # ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
208 | ref_image_latents = []
209 | for frame_idx in range(ref_image_tensor.shape[0]):
210 | ref_image_latents.append(self.vae.encode(ref_image_tensor[frame_idx : frame_idx + 1]).latent_dist.mean)
211 | ref_image_latents = torch.cat(ref_image_latents, dim=0)
212 |
213 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
214 |
215 | if do_classifier_free_guidance:
216 | uncond_audio_tensor = torch.zeros_like(audio_tensor)
217 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
218 | audio_tensor = audio_tensor.to(dtype=self.diffusion_net.dtype, device=self.diffusion_net.device)
219 |
220 | # denoising loop
221 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
222 | with self.progress_bar(total=num_inference_steps) as progress_bar:
223 | for i in range(len(timesteps)):
224 | t = timesteps[i]
225 | # Forward reference image
226 | if i == 0:
227 | ref_features = self.reference_net(
228 | ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1),
229 | torch.zeros_like(t),
230 | encoder_hidden_states=encoder_hidden_states,
231 | return_dict=False,
232 | )
233 |
234 | # expand the latents if we are doing classifier free guidance
235 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
236 | if hasattr(self.scheduler, "scale_model_input"):
237 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
238 |
239 | audio_emotion = torch.tensor(torch.mode(audio_emotion).values.item()).to(
240 | dtype=torch.int, device=self.diffusion_net.device
241 | )
242 | if do_classifier_free_guidance:
243 | uncond_audio_emotion = torch.full_like(audio_emotion, emotion_class_num)
244 | audio_emotion = torch.cat(
245 | [uncond_audio_emotion.unsqueeze(0), audio_emotion.unsqueeze(0)],
246 | dim=0,
247 | )
248 |
249 | uc_mask = (
250 | torch.Tensor(
251 | [1] * batch_size * num_images_per_prompt * 16
252 | + [0] * batch_size * num_images_per_prompt * 16
253 | )
254 | .to(device)
255 | .bool()
256 | )
257 | else:
258 | uc_mask = None
259 |
260 | noise_pred = self.diffusion_net(
261 | latent_model_input,
262 | ref_features,
263 | t,
264 | encoder_hidden_states=encoder_hidden_states,
265 | audio_embedding=audio_tensor,
266 | audio_emotion=audio_emotion,
267 | uc_mask=uc_mask,
268 | ).sample
269 |
270 | # perform guidance
271 | if do_classifier_free_guidance:
272 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
273 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
274 |
275 | # compute the previous noisy sample x_t -> x_t-1
276 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
277 |
278 | # call the callback, if provided
279 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
280 | progress_bar.update()
281 | if callback is not None and i % callback_steps == 0:
282 | step_idx = i // getattr(self.scheduler, "order", 1)
283 | callback(step_idx, t, latents)
284 |
285 | # Post-processing
286 | images = self.decode_latents(latents) # (b, c, f, h, w)
287 |
288 | # Convert to tensor
289 | if output_type == "tensor":
290 | images = torch.from_numpy(images)
291 |
292 | if not return_dict:
293 | return images
294 |
295 | return VideoPipelineOutput(videos=images)
296 |
--------------------------------------------------------------------------------
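A minimal sketch, not part of the repository, of the classifier-free guidance step used inside VideoPipeline.__call__ above: the latents are duplicated into an unconditional/conditional batch, the model output is split, and the two predictions are blended by guidance_scale. The random tensors stand in for real latents and for the diffusion_net output.

import torch

guidance_scale = 3.5
latents = torch.randn(1, 4, 16, 64, 64)            # (b, c, f, h, w)
latent_model_input = torch.cat([latents] * 2)      # unconditional + conditional copies

noise_pred = torch.randn_like(latent_model_input)  # placeholder for diffusion_net(...).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == latents.shape           # ready for scheduler.step(noise_pred, t, latents)

--------------------------------------------------------------------------------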
/memo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/if-ai/ComfyUI-IF_MemoAvatar/bcc480cd79ee97be27de21b5b797356e79116b05/memo/utils/__init__.py
--------------------------------------------------------------------------------
/memo/utils/audio_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 | import subprocess
5 | from io import BytesIO
6 |
7 | import librosa
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import torchaudio
12 | from audio_separator.separator import Separator
13 | from einops import rearrange
14 | from funasr.download.download_from_hub import download_model
15 | from funasr.models.emotion2vec.model import Emotion2vec
16 | from transformers import Wav2Vec2FeatureExtractor
17 | from safetensors.torch import load_file
18 |
19 | from memo.models.emotion_classifier import AudioEmotionClassifierModel
20 | from memo.models.wav2vec import Wav2VecModel
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int = 16000):
27 | p = subprocess.Popen(
28 | [
29 | "ffmpeg",
30 | "-y",
31 | "-v",
32 | "error",
33 | "-i",
34 | input_audio_file,
35 | "-ar",
36 | str(sample_rate),
37 | output_audio_file,
38 | ]
39 | )
40 | ret = p.wait()
41 | assert ret == 0, f"Resample audio failed! Input: {input_audio_file}, Output: {output_audio_file}"
42 | return output_audio_file
43 |
44 |
45 | @torch.no_grad()
46 | def preprocess_audio(
47 | wav_path: str,
48 | fps: int,
49 | wav2vec_model: str,
50 | vocal_separator_model: str = None,
51 | cache_dir: str = "",
52 | device: str = "cuda",
53 | sample_rate: int = 16000,
54 | num_generated_frames_per_clip: int = -1,
55 | ):
56 | """
57 | Preprocess the audio file and extract audio embeddings.
58 |
59 | Args:
60 | wav_path (str): Path to the input audio file.
61 | fps (int): Frames per second for the audio processing.
62 | wav2vec_model (str): Path to the pretrained Wav2Vec model.
63 | vocal_separator_model (str, optional): Path to the vocal separator model. Defaults to None.
64 | cache_dir (str, optional): Directory for cached files. Defaults to "".
65 | device (str, optional): Device to use ('cuda' or 'cpu'). Defaults to "cuda".
66 | sample_rate (int, optional): Sampling rate for audio processing. Defaults to 16000.
67 | num_generated_frames_per_clip (int, optional): Number of generated frames per clip for padding. Defaults to -1.
68 |
69 | Returns:
70 | tuple: A tuple containing:
71 | - audio_emb (torch.Tensor): The processed audio embeddings.
72 | - audio_length (int): The length of the audio in frames.
73 | """
74 | # Initialize Wav2Vec model
75 | audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model).to(device=device)
76 | audio_encoder.feature_extractor._freeze_parameters()
77 |
78 | # Initialize Wav2Vec feature extractor
79 | wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model)
80 |
81 | # Initialize vocal separator if provided
82 | vocal_separator = None
83 | if vocal_separator_model is not None:
84 | os.makedirs(cache_dir, exist_ok=True)
85 | vocal_separator = Separator(
86 | output_dir=cache_dir,
87 | output_single_stem="vocals",
88 | model_file_dir=os.path.dirname(vocal_separator_model),
89 | )
90 | vocal_separator.load_model(os.path.basename(vocal_separator_model))
91 | assert vocal_separator.model_instance is not None, "Failed to load audio separation model."
92 |
93 | # Perform vocal separation if applicable
94 | if vocal_separator is not None:
95 | outputs = vocal_separator.separate(wav_path)
96 | assert len(outputs) > 0, "Audio separation failed."
97 | vocal_audio_file = outputs[0]
98 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
99 | vocal_audio_file = os.path.join(vocal_separator.output_dir, vocal_audio_file)
100 | vocal_audio_file = resample_audio(
101 | vocal_audio_file,
102 | os.path.join(vocal_separator.output_dir, f"{vocal_audio_name}-16k.wav"),
103 | sample_rate,
104 | )
105 | else:
106 | vocal_audio_file = wav_path
107 |
108 | # Load audio and extract Wav2Vec features
109 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=sample_rate)
110 | audio_feature = np.squeeze(wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
111 | audio_length = math.ceil(len(audio_feature) / sample_rate * fps)
112 | audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
113 |
114 | # Pad audio features to match the required length
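   | # e.g. with fps=30, sample_rate=16000, num_generated_frames_per_clip=16 and audio_length=70,
   | # the remainder is 70 % 16 = 6, so 10 extra frames (10 * (16000 // 30) = 5330 samples) of
   | # silence are appended and audio_length is rounded up to 80 (illustrative numbers only)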
115 | if num_generated_frames_per_clip > 0 and audio_length % num_generated_frames_per_clip != 0:
116 | audio_feature = torch.nn.functional.pad(
117 | audio_feature,
118 | (
119 | 0,
120 | (num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip) * (sample_rate // fps),
121 | ),
122 | "constant",
123 | 0.0,
124 | )
125 | audio_length += num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip
126 | audio_feature = audio_feature.unsqueeze(0)
127 |
128 | # Extract audio embeddings
129 | with torch.no_grad():
130 | embeddings = audio_encoder(audio_feature, seq_len=audio_length, output_hidden_states=True)
131 | assert len(embeddings) > 0, "Failed to extract audio embeddings."
132 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
133 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
134 |
135 | # Concatenate embeddings with surrounding frames
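   | # (each frame embedding is stacked with its two neighbours on either side, clamped at the
   | # clip boundaries, giving a 5-frame temporal context window per output position)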
136 | audio_emb = audio_emb.cpu().detach()
137 | concatenated_tensors = []
138 | for i in range(audio_emb.shape[0]):
139 | vectors_to_concat = [audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
140 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
141 | audio_emb = torch.stack(concatenated_tensors, dim=0)
142 |
143 | if vocal_separator is not None:
144 | del vocal_separator
145 | del audio_encoder
146 |
147 | return audio_emb, audio_length
148 |
149 |
150 | @torch.no_grad()
151 | def extract_audio_emotion_labels(
152 | model: str,
153 | wav_path: str,
154 | emotion2vec_model: str,
155 | audio_length: int,
156 | sample_rate: int = 16000,
157 | device: str = "cuda",
158 | ):
159 | """
160 | Extract audio emotion labels from an audio file.
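   |
   | The emotion2vec backbone embeds the waveform and the MEMO audio emotion classifier predicts
   | a label for the start segment, sliding middle windows, and end segment of the audio; the
   | per-sample labels are then interpolated to one value per generated video frame.
   |
   | Returns:
   |     tuple: (emotion_labels, num_emotion_classes), where emotion_labels is an int tensor of
   |     length audio_length and num_emotion_classes is taken from the classifier.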
161 | """
162 | # Load models
163 | logger.info("Downloading emotion2vec models from modelscope")
164 | kwargs = download_model(model=emotion2vec_model)
165 | kwargs["tokenizer"] = None
166 | kwargs["input_size"] = None
167 | kwargs["frontend"] = None
168 | emotion_model = Emotion2vec(**kwargs, vocab_size=-1).to(device)
169 | init_param = kwargs.get("init_param", None)
170 | load_emotion2vec_model(
171 | model=emotion_model,
172 | path=init_param,
173 | ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
174 | oss_bucket=kwargs.get("oss_bucket", None),
175 | scope_map=kwargs.get("scope_map", []),
176 | )
177 | emotion_model.eval()
178 |
179 | # Create and load emotion classifier directly
180 | classifier = AudioEmotionClassifierModel().to(device=device)
181 | emotion_classifier_path = os.path.join(
182 | model,
183 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors"
184 | )
185 | classifier.load_state_dict(load_file(emotion_classifier_path))
186 | classifier.eval()
187 |
188 | # Load audio
189 | wav, sr = torchaudio.load(wav_path)
190 | if sr != sample_rate:
191 | wav = torchaudio.functional.resample(wav, sr, sample_rate)
192 | wav = wav.view(-1) if wav.dim() == 1 else wav[0].view(-1)
193 |
194 | emotion_labels = torch.full_like(wav, -1, dtype=torch.int32)
195 |
196 | def extract_emotion(x):
197 | """
198 | Extract emotion for a given audio segment.
199 | """
200 | x = x.to(device=device)
201 | x = F.layer_norm(x, x.shape).view(1, -1)
202 | feats = emotion_model.extract_features(x)
203 | x = feats["x"].mean(dim=1) # average across frames
204 | x = classifier(x)
205 | x = torch.softmax(x, dim=-1)
206 | return torch.argmax(x, dim=-1)
207 |
208 | # Process start, middle, and end segments
209 | start_label = extract_emotion(wav[: sample_rate * 2]).item()
210 | emotion_labels[:sample_rate] = start_label
211 |
212 | for i in range(sample_rate, len(wav) - sample_rate, sample_rate):
213 | mid_wav = wav[i - sample_rate : i - sample_rate + sample_rate * 3]
214 | mid_label = extract_emotion(mid_wav).item()
215 | emotion_labels[i : i + sample_rate] = mid_label
216 |
217 | end_label = extract_emotion(wav[-sample_rate * 2 :]).item()
218 | emotion_labels[-sample_rate:] = end_label
219 |
220 | # Interpolate to match the target audio length
221 | emotion_labels = emotion_labels.unsqueeze(0).unsqueeze(0).float()
222 | emotion_labels = F.interpolate(emotion_labels, size=audio_length, mode="nearest").squeeze(0).squeeze(0).int()
223 | num_emotion_classes = classifier.num_emotion_classes
224 |
225 | del emotion_model
226 | del classifier
227 |
228 | return emotion_labels, num_emotion_classes
229 |
230 |
231 | def load_emotion2vec_model(
232 | path: str,
233 | model: torch.nn.Module,
234 | ignore_init_mismatch: bool = True,
235 | map_location: str = "cpu",
236 | oss_bucket=None,
237 | scope_map=None,
238 | ):
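   | """
   | Load an emotion2vec checkpoint into `model`, optionally remapping parameter names with
   | `scope_map` (alternating source/destination prefixes) and skipping tensors whose shapes
   | do not match when `ignore_init_mismatch` is True. If `oss_bucket` is given, the checkpoint
   | is read from object storage instead of the local filesystem.
   | """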
239 | obj = model
240 | dst_state = obj.state_dict()
241 | logger.debug(f"Emotion2vec checkpoint: {path}")
242 | if oss_bucket is None:
243 | src_state = torch.load(path, map_location=map_location)
244 | else:
245 | buffer = BytesIO(oss_bucket.get_object(path).read())
246 | src_state = torch.load(buffer, map_location=map_location)
247 |
248 | src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
249 | src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
250 | src_state = src_state["model"] if "model" in src_state else src_state
251 |
252 | if isinstance(scope_map, str):
253 | scope_map = scope_map.split(",")
254 | scope_map = list(scope_map or []) + ["module.", "None"]
255 |
256 | for k in dst_state.keys():
257 | k_src = k
258 | if scope_map is not None:
259 | src_prefix = ""
260 | dst_prefix = ""
261 | for i in range(0, len(scope_map), 2):
262 | src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
263 | dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
264 |
265 | if dst_prefix == "" and (src_prefix + k) in src_state.keys():
266 | k_src = src_prefix + k
267 | if not k_src.startswith("module."):
268 | logger.debug(f"init param, map: {k} from {k_src} in ckpt")
269 | elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
270 | k_src = k.replace(dst_prefix, src_prefix, 1)
271 | if not k_src.startswith("module."):
272 | logger.debug(f"init param, map: {k} from {k_src} in ckpt")
273 |
274 | if k_src in src_state.keys():
275 | if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
276 | logger.debug(
277 | f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
278 | )
279 | else:
280 | dst_state[k] = src_state[k_src]
281 |
282 | else:
283 | logger.debug(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
284 |
285 | obj.load_state_dict(dst_state, strict=True)
286 |
--------------------------------------------------------------------------------
/memo/utils/vision_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from insightface.app import FaceAnalysis
8 | from moviepy.editor import AudioFileClip, VideoClip
9 | from PIL import Image
10 | from torchvision import transforms
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def tensor_to_video(tensor, output_video_path, input_audio_path, fps=30):
17 | """
18 | Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file.
19 |
20 | Args:
21 | tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w].
22 | output_video_path (str): The file path where the output video will be saved.
23 | input_audio_path (str): The path to the audio file (WAV file) that contains the audio track to be added.
24 | fps (int): The frame rate of the output video. Default is 30 fps.
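   |
   | Example (a minimal sketch; the tensor and file paths are placeholders):
   |     frames = torch.rand(3, 48, 256, 256)  # [c, f, h, w] in [0, 1]
   |     tensor_to_video(frames, "out.mp4", "speech.wav", fps=16)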
25 | """
26 | tensor = tensor.permute(1, 2, 3, 0).cpu().numpy() # convert to [f, h, w, c]
27 | tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8) # to [0, 255]
28 |
29 | def make_frame(t):
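   | # map the clip time t (seconds) to a frame index, clamped to the last available frame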
30 | frame_index = min(int(t * fps), tensor.shape[0] - 1)
31 | return tensor[frame_index]
32 |
33 | video_duration = tensor.shape[0] / fps
34 | audio_clip = AudioFileClip(input_audio_path)
35 | audio_duration = audio_clip.duration
36 | final_duration = min(video_duration, audio_duration)
37 | audio_clip = audio_clip.subclip(0, final_duration)
38 | new_video_clip = VideoClip(make_frame, duration=final_duration)
39 | new_video_clip = new_video_clip.set_audio(audio_clip)
40 | new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
41 |
42 |
43 | @torch.no_grad()
44 | def preprocess_image(face_analysis_model, image_path, image_size):
45 | """Preprocess image for MEMO pipeline"""
46 | # Modify face analysis initialization
47 | face_analysis = FaceAnalysis(
48 | name="",
49 | root=face_analysis_model, # Use parent directory
50 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
51 | )
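   | # ctx_id=0 selects the first GPU when the CUDA provider is available (CPU otherwise);
   | # det_size is the input resolution used by the face detector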
52 | face_analysis.prepare(ctx_id=0, det_size=(640, 640))
53 |
54 | # Define the image transformation
55 | transform = transforms.Compose(
56 | [
57 | transforms.Resize((image_size, image_size)),
58 | transforms.ToTensor(),
59 | transforms.Normalize([0.5], [0.5]),
60 | ]
61 | )
62 |
63 | # Load and preprocess the image
64 | image = Image.open(image_path).convert("RGB")
65 | pixel_values = transform(image)
66 | pixel_values = pixel_values.unsqueeze(0)
67 |
68 | # Detect faces and extract the face embedding
69 | image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
70 | faces = face_analysis.get(image_bgr)
71 | if not faces:
72 | logger.warning("No faces detected in the image. Using a zero vector as the face embedding.")
73 | face_emb = np.zeros(512)
74 | else:
75 | # Sort faces by size and select the largest one
76 | faces_sorted = sorted(
77 | faces,
78 | key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
79 | reverse=True,
80 | )
81 | if "embedding" not in faces_sorted[0]:
82 | logger.warning("The detected face does not have an 'embedding'. Using a zero vector.")
83 | face_emb = np.zeros(512)
84 | else:
85 | face_emb = faces_sorted[0]["embedding"]
86 |
87 | # Convert face embedding to a PyTorch tensor
88 | face_emb = face_emb.reshape(1, -1)
89 | face_emb = torch.tensor(face_emb)
90 |
91 | del face_analysis
92 |
93 | return pixel_values, face_emb
94 |
--------------------------------------------------------------------------------
/memo_model_manager.py:
--------------------------------------------------------------------------------
1 | #memo_model_manager.py
2 | import os
3 | import logging
4 | import json
5 | import shutil
6 | from huggingface_hub import hf_hub_download
7 | from modelscope import snapshot_download
8 | import folder_paths
9 |
10 | logger = logging.getLogger("memo")
11 |
12 | class MemoModelManager:
13 | def __init__(self):
14 | self.models_base = folder_paths.models_dir
15 | self._setup_paths()
16 | self._ensure_model_structure()
17 |
18 | def _setup_paths(self):
19 | """Initialize base paths structure"""
20 | self.paths = {
21 | "memo_base": os.path.join(self.models_base, "checkpoints", "memo"),
22 | "wav2vec": os.path.join(self.models_base, "wav2vec", "facebook", "wav2vec2-base-960h"),
23 | "emotion2vec": os.path.join(self.models_base, "emotion2vec", "iic", "emotion2vec_plus_large"),
24 | "vae": os.path.join(self.models_base, "vae", "stabilityai", "sd-vae-ft-mse")
25 | }
26 |
27 | # Create directories
28 | for path in self.paths.values():
29 | os.makedirs(path, exist_ok=True)
30 |
31 | # Create memo subfolders
32 | for subdir in ["reference_net", "diffusion_net", "image_proj", "audio_proj",
33 | "misc/audio_emotion_classifier", "misc/face_analysis", "misc/vocal_separator"]:
34 | os.makedirs(os.path.join(self.paths["memo_base"], subdir), exist_ok=True)
35 |
36 | def _direct_download(self, repo_id, filename, target_path, force=False):
37 | """Download directly to target path without extra nesting"""
38 | try:
39 | if not force and os.path.exists(target_path):
40 | return target_path
41 |
42 | download_path = hf_hub_download(
43 | repo_id=repo_id,
44 | filename=filename,
45 | local_dir=os.path.dirname(target_path),
46 | local_dir_use_symlinks=False
47 | )
48 |
49 | # Move if downloaded to wrong location
50 | if download_path != target_path:
51 | shutil.move(download_path, target_path)
52 |
53 | return target_path
54 | except Exception as e:
55 | logger.warning(f"Failed to download {filename} from {repo_id} to {target_path}: {e}")
56 | return None
57 |
58 | def _setup_face_analysis(self):
59 | """Setup face analysis models with correct structure"""
60 | face_dir = os.path.join(self.paths["memo_base"], "misc", "face_analysis")
61 | models_dir = os.path.join(face_dir, "models") # Create a models subdirectory
62 | os.makedirs(models_dir, exist_ok=True)
63 |
64 | # Create models.json
65 | models_json = {
66 | "detection": ["scrfd_10g_bnkps"],
67 | "recognition": ["glintr100"],
68 | "analysis": ["genderage", "2d106det", "1k3d68"]
69 | }
70 |
71 | # Write models.json
72 | models_json_path = os.path.join(face_dir, "models.json")
73 | with open(models_json_path, "w") as f:
74 | json.dump(models_json, f, indent=2)
75 |
76 | # Create version.txt
77 | version_path = os.path.join(face_dir, "version.txt")
78 | with open(version_path, "w") as f:
79 | f.write("0.7.3")
80 |
81 | # Download model files if they don't exist
82 | required_models = {
83 | "scrfd_10g_bnkps.onnx": "scrfd_10g_bnkps",
84 | "glintr100.onnx": "glintr100",
85 | "genderage.onnx": "genderage",
86 | "2d106det.onnx": "2d106det",
87 | "1k3d68.onnx": "1k3d68",
88 | "face_landmarker_v2_with_blendshapes.task": "face_landmarker"
89 | }
90 |
91 | for model_file, model_name in required_models.items():
92 | target_path = os.path.join(models_dir, model_file) # Save in models subdirectory
93 | if not os.path.exists(target_path):
94 | self._direct_download(
95 | "memoavatar/memo",
96 | f"misc/face_analysis/models/{model_file}",
97 | target_path
98 | )
99 | # Create symlink in parent directory for compatibility
100 | parent_target = os.path.join(face_dir, model_file)
101 | if not os.path.exists(parent_target):
102 | if os.name == 'nt': # Windows: creating symlinks may require elevated privileges,
103 | # so copy the model file instead of symlinking
104 | shutil.copy2(target_path, parent_target)
105 | else: # Unix-like
106 | os.symlink(target_path, parent_target)
107 |
108 | # Set environment variable for face models
109 | os.environ["MEMO_FACE_MODELS"] = face_dir
110 | return face_dir
111 |
112 | def _ensure_model_structure(self):
113 | """Download all required models to correct locations"""
114 | # Set up face analysis and environment variables first
115 | face_dir = self._setup_face_analysis()
116 | os.environ["MEMO_FACE_MODELS"] = face_dir
117 | os.environ["MEMO_VOCAL_MODEL"] = os.path.join(
118 | self.paths["memo_base"], "misc/vocal_separator/Kim_Vocal_2.onnx"
119 | )
120 |
121 | # Download memo components
122 | components = {
123 | "reference_net": ["config.json", "diffusion_pytorch_model.safetensors"],
124 | "diffusion_net": ["config.json", "diffusion_pytorch_model.safetensors"],
125 | "image_proj": ["config.json", "diffusion_pytorch_model.safetensors"],
126 | "audio_proj": ["config.json", "diffusion_pytorch_model.safetensors"]
127 | }
128 |
129 | for component, files in components.items():
130 | component_dir = os.path.join(self.paths["memo_base"], component)
131 | for file in files:
132 | self._direct_download(
133 | "memoavatar/memo",
134 | f"{component}/{file}",
135 | os.path.join(component_dir, file)
136 | )
137 |
138 | # Download vocal separator
139 | self._direct_download(
140 | "memoavatar/memo",
141 | "misc/vocal_separator/Kim_Vocal_2.onnx",
142 | os.path.join(self.paths["memo_base"], "misc/vocal_separator/Kim_Vocal_2.onnx")
143 | )
144 |
145 | # Download emotion classifier
146 | self._direct_download(
147 | "memoavatar/memo",
148 | "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors",
149 | os.path.join(self.paths["memo_base"], "misc/audio_emotion_classifier/diffusion_pytorch_model.safetensors")
150 | )
151 |
152 | # Download wav2vec files
153 | for file in ["config.json", "preprocessor_config.json", "pytorch_model.bin",
154 | "special_tokens_map.json", "tokenizer_config.json", "vocab.json"]:
155 | self._direct_download(
156 | "facebook/wav2vec2-base-960h",
157 | file,
158 | os.path.join(self.paths["wav2vec"], file)
159 | )
160 |
161 | # Download emotion2vec
162 | try:
163 | snapshot_download(
164 | "iic/emotion2vec_plus_large",
165 | local_dir=self.paths["emotion2vec"]
166 | )
167 | except Exception as e:
168 | logger.warning(f"Failed to download emotion2vec model: {e}")
169 |
170 | # Download VAE
171 | for file in ["config.json", "diffusion_pytorch_model.safetensors"]:
172 | self._direct_download(
173 | "stabilityai/sd-vae-ft-mse",
174 | file,
175 | os.path.join(self.paths["vae"], file)
176 | )
177 |
178 | def get_model_paths(self):
179 | """Return paths dictionary"""
180 | return {
181 | "memo_base": self.paths["memo_base"],
182 | "face_models": os.path.join(self.paths["memo_base"], "misc/face_analysis"),
183 | "vocal_separator": os.path.join(self.paths["memo_base"], "misc/vocal_separator/Kim_Vocal_2.onnx"),
184 | "wav2vec": self.paths["wav2vec"],
185 | "emotion2vec": self.paths["emotion2vec"],
186 | "vae": self.paths["vae"]
187 | }
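188 |
189 | # Example usage (a minimal sketch):
190 | #   manager = MemoModelManager()        # downloads any missing checkpoints on first use
191 | #   paths = manager.get_model_paths()
192 | #   wav2vec_dir = paths["wav2vec"]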
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui_if_memoavatar"
3 | description = "Talking avatars with MemoAvatar: Memory-Guided Diffusion for Expressive Talking Video Generation"
4 | version = "0.0.3"
5 | license = { file = "LICENSE" }
6 | dependencies = [
7 | "albumentations==1.4.21",
8 | "modelscope==1.20.1",
9 | "numba==0.60.0",
10 | "librosa==0.10.2",
11 | "diffusers>=0.31.0",
12 | "transformers>=4.46.3",
13 | "numpy>=1.26.4",
14 | "PyYAML>=6.0.1",
15 | "moviepy==1.0.3",
16 | "pillow>=10.4.0",
17 | "funasr==1.0.27",
18 | "insightface==0.7.3",
19 | "accelerate==1.1.1",
20 | "black==23.12.1",
21 | "ffmpeg-python==0.2.0",
22 | "huggingface-hub==0.26.2",
23 | "imageio==2.36.0",
24 | "imageio-ffmpeg==0.5.1",
25 | "hydra-core==1.3.2",
26 | "jax==0.4.35",
27 | "mediapipe==0.10.18",
28 | "omegaconf==2.3.0",
29 | "onnxruntime-gpu>=1.20.1",
30 | "opencv-python-headless==4.10.0.84",
31 | "scikit-learn>=1.5.2",
32 | "scipy>=1.14.1",
33 | "tqdm>=4.67.1",
34 | ]
35 |
36 | [project.urls]
37 | Repository = "https://github.com/if-ai/ComfyUI-IF_MemoAvatar"
38 |
39 | [tool.comfy]
40 | PublisherId = "impactframes"
41 | DisplayName = "IF_MemoAvatar"
42 | Icon = "https://impactframes.ai/System/Icons/48x48/if.png"
43 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git-lfs
2 | diffusers>=0.31.0
3 | audio-separator==0.24.1
4 | albumentations==1.4.21
5 | numba
6 | librosa==0.10.2
7 | modelscope==1.20.1
8 | transformers>=4.46.3
9 | numpy>=1.26.4
10 | PyYAML>=6.0.1
11 | moviepy==1.0.3
12 | pillow>=10.4.0
13 | funasr==1.0.27
14 | insightface==0.7.3
15 | accelerate==1.1.1
16 | black==23.12.1
17 | einops==0.8.0
18 | ffmpeg-python==0.2.0
19 | huggingface-hub==0.26.2
20 | imageio==2.36.0
21 | imageio-ffmpeg==0.5.1
22 | hydra-core==1.3.2
23 | jax==0.4.35
24 | mediapipe==0.10.18
25 | omegaconf==2.3.0
26 | onnxruntime>=1.20.1
27 | onnxruntime-gpu>=1.20.1
28 | opencv-python-headless==4.10.0.84
29 | scikit-learn>=1.5.2
30 | scipy>=1.14.1
31 | tqdm>=4.67.1
32 |
--------------------------------------------------------------------------------
/web/js/IF_MemoAvatar.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 |
3 | // Create base styles for buttons
4 | const style = document.createElement('style');
5 | style.textContent = `
6 | .if-button {
7 | background: var(--comfy-input-bg);
8 | border: 1px solid var(--border-color);
9 | color: var(--input-text);
10 | padding: 4px 12px;
11 | border-radius: 4px;
12 | cursor: pointer;
13 | transition: all 0.2s ease;
14 | margin-right: 5px;
15 | }
16 |
17 | .if-button:hover {
18 | background: var(--comfy-input-bg-hover);
19 | }
20 | `;
21 | document.head.appendChild(style);
22 |
23 | app.registerExtension({
24 | name: "Comfy.IF_MemoAvatar",
25 | async beforeRegisterNodeDef(nodeType, nodeData) {
26 | if (nodeData.name !== "IF_MemoAvatar") return;
27 |
28 | const origOnNodeCreated = nodeType.prototype.onNodeCreated;
29 |
30 | nodeType.prototype.onNodeCreated = function() {
31 | const result = origOnNodeCreated?.apply(this, arguments);
32 |
33 | // Add preview widget
34 | this.addWidget("preview", "preview", {
35 | serialize: false,
36 | size: [256, 256]
37 | });
38 |
39 | // Ensure proper widget parent references
40 | if (this.widgets) {
41 | this.widgets.forEach(w => w.parent = this);
42 | }
43 |
44 | return result;
45 | };
46 |
47 | // Add size handling
48 | nodeType.prototype.onResize = function(size) {
49 | const minWidth = 400;
50 | const minHeight = 200;
51 | size[0] = Math.max(size[0], minWidth);
52 | size[1] = Math.max(size[1], minHeight);
53 | };
54 |
55 | // Add execution handling
56 | nodeType.prototype.onExecuted = function(message) {
57 | if (message.preview) {
58 | const previewWidget = this.widgets.find(w => w.name === "preview");
59 | if (previewWidget) {
60 | previewWidget.value = message.preview;
61 | }
62 | }
63 | };
64 | }
65 | });
--------------------------------------------------------------------------------
/workflow/IF_MemoAvatar.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 67,
3 | "last_link_id": 114,
4 | "nodes": [
5 | {
6 | "id": 49,
7 | "type": "Note",
8 | "pos": [
9 | 1176.2391357421875,
10 | 654.4822387695312
11 | ],
12 | "size": [
13 | 210,
14 | 156.80001831054688
15 | ],
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [],
20 | "outputs": [],
21 | "properties": {},
22 | "widgets_values": [
23 | "This does not populate, but the video gets saved to the output folder.\nYou can load it with a VHS node."
24 | ],
25 | "color": "#432",
26 | "bgcolor": "#653"
27 | },
28 | {
29 | "id": 56,
30 | "type": "PreviewImage",
31 | "pos": [
32 | -174.07064819335938,
33 | 707.8036499023438
34 | ],
35 | "size": [
36 | 211.03074645996094,
37 | 246
38 | ],
39 | "flags": {},
40 | "order": 8,
41 | "mode": 0,
42 | "inputs": [
43 | {
44 | "name": "images",
45 | "type": "IMAGE",
46 | "link": 90
47 | }
48 | ],
49 | "outputs": [],
50 | "properties": {
51 | "Node name for S&R": "PreviewImage"
52 | },
53 | "widgets_values": []
54 | },
55 | {
56 | "id": 54,
57 | "type": "ImageResizeKJ",
58 | "pos": [
59 | -164.07064819335938,
60 | 157.803466796875
61 | ],
62 | "size": [
63 | 210,
64 | 238
65 | ],
66 | "flags": {
67 | "collapsed": false
68 | },
69 | "order": 6,
70 | "mode": 0,
71 | "inputs": [
72 | {
73 | "name": "image",
74 | "type": "IMAGE",
75 | "link": 84
76 | },
77 | {
78 | "name": "get_image_size",
79 | "type": "IMAGE",
80 | "link": null,
81 | "shape": 7
82 | },
83 | {
84 | "name": "width_input",
85 | "type": "INT",
86 | "link": null,
87 | "widget": {
88 | "name": "width_input"
89 | },
90 | "shape": 7
91 | },
92 | {
93 | "name": "height_input",
94 | "type": "INT",
95 | "link": null,
96 | "widget": {
97 | "name": "height_input"
98 | },
99 | "shape": 7
100 | }
101 | ],
102 | "outputs": [
103 | {
104 | "name": "IMAGE",
105 | "type": "IMAGE",
106 | "links": [
107 | 89
108 | ],
109 | "slot_index": 0
110 | },
111 | {
112 | "name": "width",
113 | "type": "INT",
114 | "links": null
115 | },
116 | {
117 | "name": "height",
118 | "type": "INT",
119 | "links": null
120 | }
121 | ],
122 | "properties": {
123 | "Node name for S&R": "ImageResizeKJ"
124 | },
125 | "widgets_values": [
126 | 1504,
127 | 1504,
128 | "lanczos",
129 | true,
130 | 2,
131 | 0,
132 | 0,
133 | "disabled"
134 | ]
135 | },
136 | {
137 | "id": 57,
138 | "type": "ImageCrop+",
139 | "pos": [
140 | -174.07064819335938,
141 | 457.80340576171875
142 | ],
143 | "size": [
144 | 210,
145 | 194
146 | ],
147 | "flags": {},
148 | "order": 7,
149 | "mode": 0,
150 | "inputs": [
151 | {
152 | "name": "image",
153 | "type": "IMAGE",
154 | "link": 89
155 | }
156 | ],
157 | "outputs": [
158 | {
159 | "name": "IMAGE",
160 | "type": "IMAGE",
161 | "links": [
162 | 90,
163 | 112
164 | ],
165 | "slot_index": 0
166 | },
167 | {
168 | "name": "x",
169 | "type": "INT",
170 | "links": null
171 | },
172 | {
173 | "name": "y",
174 | "type": "INT",
175 | "links": null
176 | }
177 | ],
178 | "properties": {
179 | "Node name for S&R": "ImageCrop+"
180 | },
181 | "widgets_values": [
182 | 1400,
183 | 1400,
184 | "top-center",
185 | -2,
186 | 51
187 | ]
188 | },
189 | {
190 | "id": 59,
191 | "type": "Fast Groups Muter (rgthree)",
192 | "pos": [
193 | 70.64138793945312,
194 | 937.8055419921875
195 | ],
196 | "size": [
197 | 226.8000030517578,
198 | 106
199 | ],
200 | "flags": {},
201 | "order": 1,
202 | "mode": 0,
203 | "inputs": [],
204 | "outputs": [
205 | {
206 | "name": "OPT_CONNECTION",
207 | "type": "*",
208 | "links": null
209 | }
210 | ],
211 | "properties": {
212 | "matchColors": "",
213 | "matchTitle": "",
214 | "showNav": true,
215 | "sort": "position",
216 | "customSortAlphabet": "",
217 | "toggleRestriction": "default"
218 | }
219 | },
220 | {
221 | "id": 32,
222 | "type": "IF_MemoCheckpointLoader",
223 | "pos": [
224 | 72.47502136230469,
225 | 424.5469665527344
226 | ],
227 | "size": [
228 | 315,
229 | 206
230 | ],
231 | "flags": {},
232 | "order": 2,
233 | "mode": 0,
234 | "inputs": [],
235 | "outputs": [
236 | {
237 | "name": "reference_net",
238 | "type": "MODEL",
239 | "links": [
240 | 100
241 | ],
242 | "slot_index": 0
243 | },
244 | {
245 | "name": "diffusion_net",
246 | "type": "MODEL",
247 | "links": [
248 | 101
249 | ],
250 | "slot_index": 1
251 | },
252 | {
253 | "name": "vae",
254 | "type": "VAE",
255 | "links": [
256 | 102
257 | ],
258 | "slot_index": 2
259 | },
260 | {
261 | "name": "image_proj",
262 | "type": "IMAGE_PROJ",
263 | "links": [
264 | 103
265 | ],
266 | "slot_index": 3
267 | },
268 | {
269 | "name": "audio_proj",
270 | "type": "AUDIO_PROJ",
271 | "links": [
272 | 104
273 | ],
274 | "slot_index": 4
275 | },
276 | {
277 | "name": "emotion_classifier",
278 | "type": "EMOTION_CLASSIFIER",
279 | "links": [
280 | 105
281 | ],
282 | "slot_index": 5
283 | }
284 | ],
285 | "properties": {
286 | "Node name for S&R": "IF_MemoCheckpointLoader"
287 | },
288 | "widgets_values": [
289 | "memoavatar/memo"
290 | ]
291 | },
292 | {
293 | "id": 60,
294 | "type": "LoadAudio",
295 | "pos": [
296 | 68.33649444580078,
297 | 675.4642944335938
298 | ],
299 | "size": [
300 | 315,
301 | 124
302 | ],
303 | "flags": {},
304 | "order": 3,
305 | "mode": 0,
306 | "inputs": [],
307 | "outputs": [
308 | {
309 | "name": "AUDIO",
310 | "type": "AUDIO",
311 | "links": [
312 | 111
313 | ],
314 | "slot_index": 0
315 | }
316 | ],
317 | "properties": {
318 | "Node name for S&R": "LoadAudio"
319 | },
320 | "widgets_values": [
321 | "Hola.wav",
322 | null,
323 | ""
324 | ]
325 | },
326 | {
327 | "id": 50,
328 | "type": "Note",
329 | "pos": [
330 | 446.33880615234375,
331 | 203.38783264160156
332 | ],
333 | "size": [
334 | 331.2768249511719,
335 | 117.3055648803711
336 | ],
337 | "flags": {},
338 | "order": 4,
339 | "mode": 0,
340 | "inputs": [],
341 | "outputs": [],
342 | "properties": {},
343 | "widgets_values": [
344 | "A 3090 runs faster at 384; 512 is slower.\nReducing the fps broke the video; this still needs testing and optimization.\nThis wrapper uses diffusers and the same code as the original repo with only a few changes. It should work on a 3090, or maybe even on lesser cards at a smaller resolution."
345 | ],
346 | "color": "#432",
347 | "bgcolor": "#653"
348 | },
349 | {
350 | "id": 31,
351 | "type": "IF_DisplayText",
352 | "pos": [
353 | 851.056640625,
354 | 660.5128784179688
355 | ],
356 | "size": [
357 | 295.78253173828125,
358 | 228.17208862304688
359 | ],
360 | "flags": {},
361 | "order": 11,
362 | "mode": 0,
363 | "inputs": [
364 | {
365 | "name": "text",
366 | "type": "STRING",
367 | "link": 108,
368 | "widget": {
369 | "name": "text"
370 | }
371 | }
372 | ],
373 | "outputs": [
374 | {
375 | "name": "text",
376 | "type": "STRING",
377 | "links": null,
378 | "tooltip": "Full text content"
379 | },
380 | {
381 | "name": "text_list",
382 | "type": "STRING",
383 | "links": null,
384 | "shape": 6,
385 | "tooltip": "Individual lines as separate outputs"
386 | },
387 | {
388 | "name": "count",
389 | "type": "INT",
390 | "links": null,
391 | "tooltip": "Total number of non-empty lines"
392 | },
393 | {
394 | "name": "selected",
395 | "type": "STRING",
396 | "links": null,
397 | "tooltip": "Currently selected line based on select input"
398 | }
399 | ],
400 | "properties": {
401 | "Node name for S&R": "IF_DisplayText"
402 | },
403 | "widgets_values": [
404 | "",
405 | 0,
406 | "✅ Video saved as memo_video_20241216-193105.mp4"
407 | ]
408 | },
409 | {
410 | "id": 61,
411 | "type": "IF_MemoAvatar",
412 | "pos": [
413 | 424.6819152832031,
414 | 392.4902648925781
415 | ],
416 | "size": [
417 | 400,
418 | 390
419 | ],
420 | "flags": {},
421 | "order": 9,
422 | "mode": 0,
423 | "inputs": [
424 | {
425 | "name": "image",
426 | "type": "IMAGE",
427 | "link": 112
428 | },
429 | {
430 | "name": "audio",
431 | "type": "AUDIO",
432 | "link": 111
433 | },
434 | {
435 | "name": "reference_net",
436 | "type": "MODEL",
437 | "link": 100
438 | },
439 | {
440 | "name": "diffusion_net",
441 | "type": "MODEL",
442 | "link": 101
443 | },
444 | {
445 | "name": "vae",
446 | "type": "VAE",
447 | "link": 102
448 | },
449 | {
450 | "name": "image_proj",
451 | "type": "IMAGE_PROJ",
452 | "link": 103
453 | },
454 | {
455 | "name": "audio_proj",
456 | "type": "AUDIO_PROJ",
457 | "link": 104
458 | },
459 | {
460 | "name": "emotion_classifier",
461 | "type": "EMOTION_CLASSIFIER",
462 | "link": 105
463 | }
464 | ],
465 | "outputs": [
466 | {
467 | "name": "video_path",
468 | "type": "STRING",
469 | "links": [
470 | 114
471 | ],
472 | "slot_index": 0
473 | },
474 | {
475 | "name": "status",
476 | "type": "STRING",
477 | "links": [
478 | 108
479 | ],
480 | "slot_index": 1
481 | }
482 | ],
483 | "properties": {
484 | "Node name for S&R": "IF_MemoAvatar"
485 | },
486 | "widgets_values": [
487 | 512,
488 | 16,
489 | 30,
490 | 20,
491 | 3.5,
492 | 35,
493 | "randomize",
494 | "memo_video",
495 | {
496 | "serialize": false,
497 | "size": [
498 | 256,
499 | 256
500 | ]
501 | }
502 | ]
503 | },
504 | {
505 | "id": 67,
506 | "type": "IF_DisplayText",
507 | "pos": [
508 | 855.5243530273438,
509 | 369.9153137207031
510 | ],
511 | "size": [
512 | 295.78253173828125,
513 | 228.17208862304688
514 | ],
515 | "flags": {},
516 | "order": 10,
517 | "mode": 0,
518 | "inputs": [
519 | {
520 | "name": "text",
521 | "type": "STRING",
522 | "link": 114,
523 | "widget": {
524 | "name": "text"
525 | }
526 | }
527 | ],
528 | "outputs": [
529 | {
530 | "name": "text",
531 | "type": "STRING",
532 | "links": null,
533 | "tooltip": "Full text content"
534 | },
535 | {
536 | "name": "text_list",
537 | "type": "STRING",
538 | "links": null,
539 | "shape": 6,
540 | "tooltip": "Individual lines as separate outputs"
541 | },
542 | {
543 | "name": "count",
544 | "type": "INT",
545 | "links": null,
546 | "tooltip": "Total number of non-empty lines"
547 | },
548 | {
549 | "name": "selected",
550 | "type": "STRING",
551 | "links": null,
552 | "tooltip": "Currently selected line based on select input"
553 | }
554 | ],
555 | "properties": {
556 | "Node name for S&R": "IF_DisplayText"
557 | },
558 | "widgets_values": [
559 | "",
560 | 0,
561 | "D:\\ComfyUI\\output\\memo_video_20241216-193105.mp4"
562 | ]
563 | },
564 | {
565 | "id": 53,
566 | "type": "IF_LoadImagesS",
567 | "pos": [
568 | -524.0703735351562,
569 | 187.803466796875
570 | ],
571 | "size": [
572 | 342.7641906738281,
573 | 730
574 | ],
575 | "flags": {},
576 | "order": 5,
577 | "mode": 0,
578 | "inputs": [],
579 | "outputs": [
580 | {
581 | "name": "images",
582 | "type": "IMAGE",
583 | "links": [
584 | 84
585 | ],
586 | "slot_index": 0,
587 | "shape": 6
588 | },
589 | {
590 | "name": "masks",
591 | "type": "MASK",
592 | "links": [],
593 | "shape": 6
594 | },
595 | {
596 | "name": "image_paths",
597 | "type": "STRING",
598 | "links": null,
599 | "shape": 6
600 | },
601 | {
602 | "name": "filenames",
603 | "type": "STRING",
604 | "links": null,
605 | "shape": 6
606 | },
607 | {
608 | "name": "count_str",
609 | "type": "STRING",
610 | "links": null,
611 | "shape": 6
612 | },
613 | {
614 | "name": "count_int",
615 | "type": "INT",
616 | "links": null,
617 | "shape": 6
618 | }
619 | ],
620 | "properties": {
621 | "Node name for S&R": "IF_LoadImagesS"
622 | },
623 | "widgets_values": [
624 | "thb_00a04b95511409f306f73f6da484ae41.jpg",
625 | "Refresh Previews 🔄",
626 | "C:\\Users\\SOYYO\\Pictures\\people",
627 | 0,
628 | 30,
629 | "100",
630 | true,
631 | 31,
632 | true,
633 | "alphabetical",
634 | "none",
635 | "red",
636 | "image",
637 | "Select Folder 📂",
638 | "Backup Input 💾",
639 | "Restore Input ♻️"
640 | ]
641 | }
642 | ],
643 | "links": [
644 | [
645 | 84,
646 | 53,
647 | 0,
648 | 54,
649 | 0,
650 | "IMAGE"
651 | ],
652 | [
653 | 89,
654 | 54,
655 | 0,
656 | 57,
657 | 0,
658 | "IMAGE"
659 | ],
660 | [
661 | 90,
662 | 57,
663 | 0,
664 | 56,
665 | 0,
666 | "IMAGE"
667 | ],
668 | [
669 | 100,
670 | 32,
671 | 0,
672 | 61,
673 | 2,
674 | "MODEL"
675 | ],
676 | [
677 | 101,
678 | 32,
679 | 1,
680 | 61,
681 | 3,
682 | "MODEL"
683 | ],
684 | [
685 | 102,
686 | 32,
687 | 2,
688 | 61,
689 | 4,
690 | "VAE"
691 | ],
692 | [
693 | 103,
694 | 32,
695 | 3,
696 | 61,
697 | 5,
698 | "IMAGE_PROJ"
699 | ],
700 | [
701 | 104,
702 | 32,
703 | 4,
704 | 61,
705 | 6,
706 | "AUDIO_PROJ"
707 | ],
708 | [
709 | 105,
710 | 32,
711 | 5,
712 | 61,
713 | 7,
714 | "EMOTION_CLASSIFIER"
715 | ],
716 | [
717 | 108,
718 | 61,
719 | 1,
720 | 31,
721 | 0,
722 | "STRING"
723 | ],
724 | [
725 | 111,
726 | 60,
727 | 0,
728 | 61,
729 | 1,
730 | "AUDIO"
731 | ],
732 | [
733 | 112,
734 | 57,
735 | 0,
736 | 61,
737 | 0,
738 | "IMAGE"
739 | ],
740 | [
741 | 114,
742 | 61,
743 | 0,
744 | 67,
745 | 0,
746 | "STRING"
747 | ]
748 | ],
749 | "groups": [
750 | {
751 | "id": 1,
752 | "title": "AVATAR",
753 | "bounding": [
754 | 65.28616333007812,
755 | 85.40703582763672,
756 | 1330.9530029296875,
757 | 813.2778930664062
758 | ],
759 | "color": "#444",
760 | "font_size": 24,
761 | "flags": {}
762 | },
763 | {
764 | "id": 2,
765 | "title": "MEMO",
766 | "bounding": [
767 | -534.0703735351562,
768 | 84.20343780517578,
769 | 589.9998168945312,
770 | 879.60009765625
771 | ],
772 | "color": "#444",
773 | "font_size": 24,
774 | "flags": {}
775 | }
776 | ],
777 | "config": {},
778 | "extra": {
779 | "ds": {
780 | "scale": 0.751314800901579,
781 | "offset": [
782 | 654.7879901456771,
783 | 52.171584759852784
784 | ]
785 | },
786 | "ue_links": []
787 | },
788 | "version": 0.4
789 | }
--------------------------------------------------------------------------------
/workflow/IF_MemoAvatar_IF_Extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 79,
3 | "last_link_id": 206,
4 | "nodes": [
5 | {
6 | "id": 56,
7 | "type": "PreviewImage",
8 | "pos": [
9 | -174.07064819335938,
10 | 707.8036499023438
11 | ],
12 | "size": [
13 | 211.03074645996094,
14 | 246
15 | ],
16 | "flags": {},
17 | "order": 7,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "images",
22 | "type": "IMAGE",
23 | "link": 90
24 | }
25 | ],
26 | "outputs": [],
27 | "properties": {
28 | "Node name for S&R": "PreviewImage"
29 | },
30 | "widgets_values": []
31 | },
32 | {
33 | "id": 50,
34 | "type": "Note",
35 | "pos": [
36 | 446.33880615234375,
37 | 203.38783264160156
38 | ],
39 | "size": [
40 | 331.2768249511719,
41 | 117.3055648803711
42 | ],
43 | "flags": {},
44 | "order": 0,
45 | "mode": 0,
46 | "inputs": [],
47 | "outputs": [],
48 | "properties": {},
49 | "widgets_values": [
50 | "A 3090 runs faster at 384; 512 is slower.\nReducing the fps broke the video; this still needs testing and optimization.\nThis wrapper uses diffusers and the same code as the original repo with only a few changes. It should work on a 3090, or maybe even on lesser cards at a smaller resolution."
51 | ],
52 | "color": "#432",
53 | "bgcolor": "#653"
54 | },
55 | {
56 | "id": 54,
57 | "type": "ImageResizeKJ",
58 | "pos": [
59 | -164.07064819335938,
60 | 157.803466796875
61 | ],
62 | "size": [
63 | 210,
64 | 238
65 | ],
66 | "flags": {
67 | "collapsed": false
68 | },
69 | "order": 5,
70 | "mode": 0,
71 | "inputs": [
72 | {
73 | "name": "image",
74 | "type": "IMAGE",
75 | "link": 84
76 | },
77 | {
78 | "name": "get_image_size",
79 | "type": "IMAGE",
80 | "link": null,
81 | "shape": 7
82 | },
83 | {
84 | "name": "width_input",
85 | "type": "INT",
86 | "link": null,
87 | "widget": {
88 | "name": "width_input"
89 | },
90 | "shape": 7
91 | },
92 | {
93 | "name": "height_input",
94 | "type": "INT",
95 | "link": null,
96 | "widget": {
97 | "name": "height_input"
98 | },
99 | "shape": 7
100 | }
101 | ],
102 | "outputs": [
103 | {
104 | "name": "IMAGE",
105 | "type": "IMAGE",
106 | "links": [
107 | 89
108 | ],
109 | "slot_index": 0
110 | },
111 | {
112 | "name": "width",
113 | "type": "INT",
114 | "links": null
115 | },
116 | {
117 | "name": "height",
118 | "type": "INT",
119 | "links": null
120 | }
121 | ],
122 | "properties": {
123 | "Node name for S&R": "ImageResizeKJ"
124 | },
125 | "widgets_values": [
126 | 1600,
127 | 1600,
128 | "lanczos",
129 | true,
130 | 2,
131 | 0,
132 | 0,
133 | "disabled"
134 | ]
135 | },
136 | {
137 | "id": 59,
138 | "type": "Fast Groups Muter (rgthree)",
139 | "pos": [
140 | 70.64138793945312,
141 | 937.8055419921875
142 | ],
143 | "size": [
144 | 226.8000030517578,
145 | 106
146 | ],
147 | "flags": {},
148 | "order": 1,
149 | "mode": 0,
150 | "inputs": [],
151 | "outputs": [
152 | {
153 | "name": "OPT_CONNECTION",
154 | "type": "*",
155 | "links": null
156 | }
157 | ],
158 | "properties": {
159 | "matchColors": "",
160 | "matchTitle": "",
161 | "showNav": true,
162 | "sort": "position",
163 | "customSortAlphabet": "",
164 | "toggleRestriction": "default"
165 | }
166 | },
167 | {
168 | "id": 67,
169 | "type": "IF_DisplayText",
170 | "pos": [
171 | 858.1091918945312,
172 | 155.067626953125
173 | ],
174 | "size": [
175 | 295.78253173828125,
176 | 228.17208862304688
177 | ],
178 | "flags": {},
179 | "order": 9,
180 | "mode": 0,
181 | "inputs": [
182 | {
183 | "name": "text",
184 | "type": "STRING",
185 | "link": 204,
186 | "widget": {
187 | "name": "text"
188 | }
189 | }
190 | ],
191 | "outputs": [
192 | {
193 | "name": "text",
194 | "type": "STRING",
195 | "links": null,
196 | "tooltip": "Full text content"
197 | },
198 | {
199 | "name": "text_list",
200 | "type": "STRING",
201 | "links": null,
202 | "shape": 6,
203 | "tooltip": "Individual lines as separate outputs"
204 | },
205 | {
206 | "name": "count",
207 | "type": "INT",
208 | "links": null,
209 | "tooltip": "Total number of non-empty lines"
210 | },
211 | {
212 | "name": "selected",
213 | "type": "STRING",
214 | "links": null,
215 | "tooltip": "Currently selected line based on select input"
216 | }
217 | ],
218 | "properties": {
219 | "Node name for S&R": "IF_DisplayText"
220 | },
221 | "widgets_values": [
222 | "",
223 | 0,
224 | "D:\\ComfyUI\\output\\memo_video_20241217-083345.mp4"
225 | ]
226 | },
227 | {
228 | "id": 31,
229 | "type": "IF_DisplayText",
230 | "pos": [
231 | 849.99169921875,
232 | 726.5304565429688
233 | ],
234 | "size": [
235 | 295.78253173828125,
236 | 228.17208862304688
237 | ],
238 | "flags": {},
239 | "order": 11,
240 | "mode": 0,
241 | "inputs": [
242 | {
243 | "name": "text",
244 | "type": "STRING",
245 | "link": 205,
246 | "widget": {
247 | "name": "text"
248 | }
249 | }
250 | ],
251 | "outputs": [
252 | {
253 | "name": "text",
254 | "type": "STRING",
255 | "links": null,
256 | "tooltip": "Full text content"
257 | },
258 | {
259 | "name": "text_list",
260 | "type": "STRING",
261 | "links": null,
262 | "shape": 6,
263 | "tooltip": "Individual lines as separate outputs"
264 | },
265 | {
266 | "name": "count",
267 | "type": "INT",
268 | "links": null,
269 | "tooltip": "Total number of non-empty lines"
270 | },
271 | {
272 | "name": "selected",
273 | "type": "STRING",
274 | "links": null,
275 | "tooltip": "Currently selected line based on select input"
276 | }
277 | ],
278 | "properties": {
279 | "Node name for S&R": "IF_DisplayText"
280 | },
281 | "widgets_values": [
282 | "",
283 | 0,
284 | "✅ Video saved as memo_video_20241217-083345.mp4"
285 | ]
286 | },
287 | {
288 | "id": 57,
289 | "type": "ImageCrop+",
290 | "pos": [
291 | -174.07064819335938,
292 | 457.80340576171875
293 | ],
294 | "size": [
295 | 210,
296 | 194
297 | ],
298 | "flags": {},
299 | "order": 6,
300 | "mode": 0,
301 | "inputs": [
302 | {
303 | "name": "image",
304 | "type": "IMAGE",
305 | "link": 89
306 | }
307 | ],
308 | "outputs": [
309 | {
310 | "name": "IMAGE",
311 | "type": "IMAGE",
312 | "links": [
313 | 90,
314 | 198
315 | ],
316 | "slot_index": 0
317 | },
318 | {
319 | "name": "x",
320 | "type": "INT",
321 | "links": null
322 | },
323 | {
324 | "name": "y",
325 | "type": "INT",
326 | "links": null
327 | }
328 | ],
329 | "properties": {
330 | "Node name for S&R": "ImageCrop+"
331 | },
332 | "widgets_values": [
333 | 1504,
334 | 1504,
335 | "center",
336 | -2,
337 | -41
338 | ]
339 | },
340 | {
341 | "id": 60,
342 | "type": "LoadAudio",
343 | "pos": [
344 | 70.73645782470703,
345 | 259.4642639160156
346 | ],
347 | "size": [
348 | 315,
349 | 124
350 | ],
351 | "flags": {},
352 | "order": 2,
353 | "mode": 0,
354 | "inputs": [],
355 | "outputs": [
356 | {
357 | "name": "AUDIO",
358 | "type": "AUDIO",
359 | "links": [
360 | 199
361 | ],
362 | "slot_index": 0
363 | }
364 | ],
365 | "properties": {
366 | "Node name for S&R": "LoadAudio"
367 | },
368 | "widgets_values": [
369 | "candy.wav",
370 | null,
371 | ""
372 | ]
373 | },
374 | {
375 | "id": 79,
376 | "type": "IF_MemoAvatar",
377 | "pos": [
378 | 415.24700927734375,
379 | 423.51800537109375
380 | ],
381 | "size": [
382 | 400,
383 | 390
384 | ],
385 | "flags": {},
386 | "order": 8,
387 | "mode": 0,
388 | "inputs": [
389 | {
390 | "name": "image",
391 | "type": "IMAGE",
392 | "link": 198
393 | },
394 | {
395 | "name": "audio",
396 | "type": "AUDIO",
397 | "link": 199
398 | },
399 | {
400 | "name": "reference_net",
401 | "type": "MODEL",
402 | "link": 196
403 | },
404 | {
405 | "name": "diffusion_net",
406 | "type": "MODEL",
407 | "link": 197
408 | },
409 | {
410 | "name": "vae",
411 | "type": "VAE",
412 | "link": 200
413 | },
414 | {
415 | "name": "image_proj",
416 | "type": "IMAGE_PROJ",
417 | "link": 201
418 | },
419 | {
420 | "name": "audio_proj",
421 | "type": "AUDIO_PROJ",
422 | "link": 202
423 | },
424 | {
425 | "name": "emotion_classifier",
426 | "type": "EMOTION_CLASSIFIER",
427 | "link": 203
428 | }
429 | ],
430 | "outputs": [
431 | {
432 | "name": "video_path",
433 | "type": "STRING",
434 | "links": [
435 | 204,
436 | 206
437 | ],
438 | "slot_index": 0
439 | },
440 | {
441 | "name": "status",
442 | "type": "STRING",
443 | "links": [
444 | 205
445 | ],
446 | "slot_index": 1
447 | }
448 | ],
449 | "properties": {
450 | "Node name for S&R": "IF_MemoAvatar"
451 | },
452 | "widgets_values": [
453 | 512,
454 | 16,
455 | 30,
456 | 20,
457 | 3.5,
458 | 2047,
459 | "randomize",
460 | "memo_video",
461 | {
462 | "serialize": false,
463 | "size": [
464 | 256,
465 | 256
466 | ]
467 | }
468 | ]
469 | },
470 | {
471 | "id": 68,
472 | "type": "VHS_LoadVideoPath",
473 | "pos": [
474 | 880.303466796875,
475 | 424.2923278808594
476 | ],
477 | "size": [
478 | 291.62481689453125,
479 | 238
480 | ],
481 | "flags": {},
482 | "order": 10,
483 | "mode": 0,
484 | "inputs": [
485 | {
486 | "name": "meta_batch",
487 | "type": "VHS_BatchManager",
488 | "link": null,
489 | "shape": 7
490 | },
491 | {
492 | "name": "vae",
493 | "type": "VAE",
494 | "link": null,
495 | "shape": 7
496 | },
497 | {
498 | "name": "video",
499 | "type": "STRING",
500 | "link": 206,
501 | "widget": {
502 | "name": "video"
503 | }
504 | }
505 | ],
506 | "outputs": [
507 | {
508 | "name": "IMAGE",
509 | "type": "IMAGE",
510 | "links": null
511 | },
512 | {
513 | "name": "frame_count",
514 | "type": "INT",
515 | "links": null
516 | },
517 | {
518 | "name": "audio",
519 | "type": "AUDIO",
520 | "links": null
521 | },
522 | {
523 | "name": "video_info",
524 | "type": "VHS_VIDEOINFO",
525 | "links": null
526 | }
527 | ],
528 | "properties": {
529 | "Node name for S&R": "VHS_LoadVideoPath"
530 | },
531 | "widgets_values": {
532 | "video": "",
533 | "force_rate": 0,
534 | "force_size": "Disabled",
535 | "custom_width": 512,
536 | "custom_height": 512,
537 | "frame_load_cap": 0,
538 | "skip_first_frames": 0,
539 | "select_every_nth": 1,
540 | "videopreview": {
541 | "hidden": false,
542 | "paused": false,
543 | "params": {
544 | "force_rate": 0,
545 | "frame_load_cap": 0,
546 | "skip_first_frames": 0,
547 | "select_every_nth": 1
548 | },
549 | "muted": false
550 | }
551 | }
552 | },
553 | {
554 | "id": 74,
555 | "type": "IF_MemoCheckpointLoader",
556 | "pos": [
557 | 80.47496032714844,
558 | 463.7469482421875
559 | ],
560 | "size": [
561 | 315,
562 | 158
563 | ],
564 | "flags": {},
565 | "order": 3,
566 | "mode": 0,
567 | "inputs": [],
568 | "outputs": [
569 | {
570 | "name": "reference_net",
571 | "type": "MODEL",
572 | "links": [
573 | 196
574 | ],
575 | "slot_index": 0
576 | },
577 | {
578 | "name": "diffusion_net",
579 | "type": "MODEL",
580 | "links": [
581 | 197
582 | ],
583 | "slot_index": 1
584 | },
585 | {
586 | "name": "vae",
587 | "type": "VAE",
588 | "links": [
589 | 200
590 | ],
591 | "slot_index": 2
592 | },
593 | {
594 | "name": "image_proj",
595 | "type": "IMAGE_PROJ",
596 | "links": [
597 | 201
598 | ],
599 | "slot_index": 3
600 | },
601 | {
602 | "name": "audio_proj",
603 | "type": "AUDIO_PROJ",
604 | "links": [
605 | 202
606 | ],
607 | "slot_index": 4
608 | },
609 | {
610 | "name": "emotion_classifier",
611 | "type": "EMOTION_CLASSIFIER",
612 | "links": [
613 | 203
614 | ],
615 | "slot_index": 5
616 | }
617 | ],
618 | "properties": {
619 | "Node name for S&R": "IF_MemoCheckpointLoader"
620 | },
621 | "widgets_values": [
622 | true
623 | ]
624 | },
625 | {
626 | "id": 53,
627 | "type": "IF_LoadImagesS",
628 | "pos": [
629 | -524.0703735351562,
630 | 187.803466796875
631 | ],
632 | "size": [
633 | 342.7641906738281,
634 | 730
635 | ],
636 | "flags": {},
637 | "order": 4,
638 | "mode": 0,
639 | "inputs": [],
640 | "outputs": [
641 | {
642 | "name": "images",
643 | "type": "IMAGE",
644 | "links": [
645 | 84
646 | ],
647 | "slot_index": 0,
648 | "shape": 6
649 | },
650 | {
651 | "name": "masks",
652 | "type": "MASK",
653 | "links": [],
654 | "shape": 6
655 | },
656 | {
657 | "name": "image_paths",
658 | "type": "STRING",
659 | "links": null,
660 | "shape": 6
661 | },
662 | {
663 | "name": "filenames",
664 | "type": "STRING",
665 | "links": null,
666 | "shape": 6
667 | },
668 | {
669 | "name": "count_str",
670 | "type": "STRING",
671 | "links": null,
672 | "shape": 6
673 | },
674 | {
675 | "name": "count_int",
676 | "type": "INT",
677 | "links": null,
678 | "shape": 6
679 | }
680 | ],
681 | "properties": {
682 | "Node name for S&R": "IF_LoadImagesS"
683 | },
684 | "widgets_values": [
685 | "thb_4bc166bd2e2e609360a884982b812965.jpg",
686 | "Refresh Previews 🔄",
687 | "C:\\Users\\SOYYO\\Pictures\\people",
688 | 0,
689 | 48,
690 | "100",
691 | true,
692 | 48,
693 | true,
694 | "alphabetical",
695 | "none",
696 | "red",
697 | "image",
698 | "Select Folder 📂",
699 | "Backup Input 💾",
700 | "Restore Input ♻️"
701 | ]
702 | }
703 | ],
704 | "links": [
705 | [
706 | 84,
707 | 53,
708 | 0,
709 | 54,
710 | 0,
711 | "IMAGE"
712 | ],
713 | [
714 | 89,
715 | 54,
716 | 0,
717 | 57,
718 | 0,
719 | "IMAGE"
720 | ],
721 | [
722 | 90,
723 | 57,
724 | 0,
725 | 56,
726 | 0,
727 | "IMAGE"
728 | ],
729 | [
730 | 196,
731 | 74,
732 | 0,
733 | 79,
734 | 2,
735 | "MODEL"
736 | ],
737 | [
738 | 197,
739 | 74,
740 | 1,
741 | 79,
742 | 3,
743 | "MODEL"
744 | ],
745 | [
746 | 198,
747 | 57,
748 | 0,
749 | 79,
750 | 0,
751 | "IMAGE"
752 | ],
753 | [
754 | 199,
755 | 60,
756 | 0,
757 | 79,
758 | 1,
759 | "AUDIO"
760 | ],
761 | [
762 | 200,
763 | 74,
764 | 2,
765 | 79,
766 | 4,
767 | "VAE"
768 | ],
769 | [
770 | 201,
771 | 74,
772 | 3,
773 | 79,
774 | 5,
775 | "IMAGE_PROJ"
776 | ],
777 | [
778 | 202,
779 | 74,
780 | 4,
781 | 79,
782 | 6,
783 | "AUDIO_PROJ"
784 | ],
785 | [
786 | 203,
787 | 74,
788 | 5,
789 | 79,
790 | 7,
791 | "EMOTION_CLASSIFIER"
792 | ],
793 | [
794 | 204,
795 | 79,
796 | 0,
797 | 67,
798 | 0,
799 | "STRING"
800 | ],
801 | [
802 | 205,
803 | 79,
804 | 1,
805 | 31,
806 | 0,
807 | "STRING"
808 | ],
809 | [
810 | 206,
811 | 79,
812 | 0,
813 | 68,
814 | 2,
815 | "STRING"
816 | ]
817 | ],
818 | "groups": [
819 | {
820 | "id": 1,
821 | "title": "AVATAR",
822 | "bounding": [
823 | 65.28616333007812,
824 | 85.40703582763672,
825 | 1330.9530029296875,
826 | 813.2778930664062
827 | ],
828 | "color": "#444",
829 | "font_size": 24,
830 | "flags": {}
831 | },
832 | {
833 | "id": 2,
834 | "title": "MEMO",
835 | "bounding": [
836 | -534.0703735351562,
837 | 84.20343780517578,
838 | 589.9998168945312,
839 | 879.60009765625
840 | ],
841 | "color": "#444",
842 | "font_size": 24,
843 | "flags": {}
844 | }
845 | ],
846 | "config": {},
847 | "extra": {
848 | "ds": {
849 | "scale": 1,
850 | "offset": [
851 | 566.6587908244418,
852 | -17.309013215858414
853 | ]
854 | },
855 | "ue_links": []
856 | },
857 | "version": 0.4
858 | }
--------------------------------------------------------------------------------
/workflow/IF_MemoAvatar_simple.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 82,
3 | "last_link_id": 209,
4 | "nodes": [
5 | {
6 | "id": 56,
7 | "type": "PreviewImage",
8 | "pos": [
9 | -174.07064819335938,
10 | 707.8036499023438
11 | ],
12 | "size": [
13 | 211.03074645996094,
14 | 246
15 | ],
16 | "flags": {},
17 | "order": 7,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "images",
22 | "type": "IMAGE",
23 | "link": 90
24 | }
25 | ],
26 | "outputs": [],
27 | "properties": {
28 | "Node name for S&R": "PreviewImage"
29 | },
30 | "widgets_values": []
31 | },
32 | {
33 | "id": 50,
34 | "type": "Note",
35 | "pos": [
36 | 446.33880615234375,
37 | 203.38783264160156
38 | ],
39 | "size": [
40 | 331.2768249511719,
41 | 117.3055648803711
42 | ],
43 | "flags": {},
44 | "order": 0,
45 | "mode": 0,
46 | "inputs": [],
47 | "outputs": [],
48 | "properties": {},
49 | "widgets_values": [
50 | "A 3090 runs faster at 384; 512 is slower.\nReducing the fps broke the video; this still needs testing and optimization.\nThis wrapper uses diffusers and the same code as the original repo with only a few changes. It should work on a 3090, or maybe even on lesser cards at a smaller resolution."
51 | ],
52 | "color": "#432",
53 | "bgcolor": "#653"
54 | },
55 | {
56 | "id": 59,
57 | "type": "Fast Groups Muter (rgthree)",
58 | "pos": [
59 | 70.64138793945312,
60 | 937.8055419921875
61 | ],
62 | "size": [
63 | 226.8000030517578,
64 | 106
65 | ],
66 | "flags": {},
67 | "order": 1,
68 | "mode": 0,
69 | "inputs": [],
70 | "outputs": [
71 | {
72 | "name": "OPT_CONNECTION",
73 | "type": "*",
74 | "links": null
75 | }
76 | ],
77 | "properties": {
78 | "matchColors": "",
79 | "matchTitle": "",
80 | "showNav": true,
81 | "sort": "position",
82 | "customSortAlphabet": "",
83 | "toggleRestriction": "default"
84 | }
85 | },
86 | {
87 | "id": 57,
88 | "type": "ImageCrop+",
89 | "pos": [
90 | -174.07064819335938,
91 | 457.80340576171875
92 | ],
93 | "size": [
94 | 210,
95 | 194
96 | ],
97 | "flags": {},
98 | "order": 6,
99 | "mode": 0,
100 | "inputs": [
101 | {
102 | "name": "image",
103 | "type": "IMAGE",
104 | "link": 89
105 | }
106 | ],
107 | "outputs": [
108 | {
109 | "name": "IMAGE",
110 | "type": "IMAGE",
111 | "links": [
112 | 90,
113 | 198
114 | ],
115 | "slot_index": 0
116 | },
117 | {
118 | "name": "x",
119 | "type": "INT",
120 | "links": null
121 | },
122 | {
123 | "name": "y",
124 | "type": "INT",
125 | "links": null
126 | }
127 | ],
128 | "properties": {
129 | "Node name for S&R": "ImageCrop+"
130 | },
131 | "widgets_values": [
132 | 1504,
133 | 1504,
134 | "center",
135 | -2,
136 | -41
137 | ]
138 | },
139 | {
140 | "id": 60,
141 | "type": "LoadAudio",
142 | "pos": [
143 | 70.73645782470703,
144 | 259.4642639160156
145 | ],
146 | "size": [
147 | 315,
148 | 124
149 | ],
150 | "flags": {},
151 | "order": 2,
152 | "mode": 0,
153 | "inputs": [],
154 | "outputs": [
155 | {
156 | "name": "AUDIO",
157 | "type": "AUDIO",
158 | "links": [
159 | 199
160 | ],
161 | "slot_index": 0
162 | }
163 | ],
164 | "properties": {
165 | "Node name for S&R": "LoadAudio"
166 | },
167 | "widgets_values": [
168 | "candy.wav",
169 | null,
170 | ""
171 | ]
172 | },
173 | {
174 | "id": 68,
175 | "type": "VHS_LoadVideoPath",
176 | "pos": [
177 | 880.303466796875,
178 | 424.2923278808594
179 | ],
180 | "size": [
181 | 291.62481689453125,
182 | 238
183 | ],
184 | "flags": {},
185 | "order": 9,
186 | "mode": 0,
187 | "inputs": [
188 | {
189 | "name": "meta_batch",
190 | "type": "VHS_BatchManager",
191 | "link": null,
192 | "shape": 7
193 | },
194 | {
195 | "name": "vae",
196 | "type": "VAE",
197 | "link": null,
198 | "shape": 7
199 | },
200 | {
201 | "name": "video",
202 | "type": "STRING",
203 | "link": 206,
204 | "widget": {
205 | "name": "video"
206 | }
207 | }
208 | ],
209 | "outputs": [
210 | {
211 | "name": "IMAGE",
212 | "type": "IMAGE",
213 | "links": null
214 | },
215 | {
216 | "name": "frame_count",
217 | "type": "INT",
218 | "links": null
219 | },
220 | {
221 | "name": "audio",
222 | "type": "AUDIO",
223 | "links": null
224 | },
225 | {
226 | "name": "video_info",
227 | "type": "VHS_VIDEOINFO",
228 | "links": null
229 | }
230 | ],
231 | "properties": {
232 | "Node name for S&R": "VHS_LoadVideoPath"
233 | },
234 | "widgets_values": {
235 | "video": "",
236 | "force_rate": 0,
237 | "force_size": "Disabled",
238 | "custom_width": 512,
239 | "custom_height": 512,
240 | "frame_load_cap": 0,
241 | "skip_first_frames": 0,
242 | "select_every_nth": 1,
243 | "videopreview": {
244 | "hidden": false,
245 | "paused": false,
246 | "params": {
247 | "force_rate": 0,
248 | "frame_load_cap": 0,
249 | "skip_first_frames": 0,
250 | "select_every_nth": 1
251 | },
252 | "muted": false
253 | }
254 | }
255 | },
256 | {
257 | "id": 74,
258 | "type": "IF_MemoCheckpointLoader",
259 | "pos": [
260 | 80.47496032714844,
261 | 463.7469482421875
262 | ],
263 | "size": [
264 | 315,
265 | 158
266 | ],
267 | "flags": {},
268 | "order": 3,
269 | "mode": 0,
270 | "inputs": [],
271 | "outputs": [
272 | {
273 | "name": "reference_net",
274 | "type": "MODEL",
275 | "links": [
276 | 196
277 | ],
278 | "slot_index": 0
279 | },
280 | {
281 | "name": "diffusion_net",
282 | "type": "MODEL",
283 | "links": [
284 | 197
285 | ],
286 | "slot_index": 1
287 | },
288 | {
289 | "name": "vae",
290 | "type": "VAE",
291 | "links": [
292 | 200
293 | ],
294 | "slot_index": 2
295 | },
296 | {
297 | "name": "image_proj",
298 | "type": "IMAGE_PROJ",
299 | "links": [
300 | 201
301 | ],
302 | "slot_index": 3
303 | },
304 | {
305 | "name": "audio_proj",
306 | "type": "AUDIO_PROJ",
307 | "links": [
308 | 202
309 | ],
310 | "slot_index": 4
311 | },
312 | {
313 | "name": "emotion_classifier",
314 | "type": "EMOTION_CLASSIFIER",
315 | "links": [
316 | 203
317 | ],
318 | "slot_index": 5
319 | }
320 | ],
321 | "properties": {
322 | "Node name for S&R": "IF_MemoCheckpointLoader"
323 | },
324 | "widgets_values": [
325 | true
326 | ]
327 | },
328 | {
329 | "id": 54,
330 | "type": "ImageResizeKJ",
331 | "pos": [
332 | -164.07064819335938,
333 | 157.803466796875
334 | ],
335 | "size": [
336 | 210,
337 | 238
338 | ],
339 | "flags": {
340 | "collapsed": false
341 | },
342 | "order": 5,
343 | "mode": 0,
344 | "inputs": [
345 | {
346 | "name": "image",
347 | "type": "IMAGE",
348 | "link": 207
349 | },
350 | {
351 | "name": "get_image_size",
352 | "type": "IMAGE",
353 | "link": null,
354 | "shape": 7
355 | },
356 | {
357 | "name": "width_input",
358 | "type": "INT",
359 | "link": null,
360 | "widget": {
361 | "name": "width_input"
362 | },
363 | "shape": 7
364 | },
365 | {
366 | "name": "height_input",
367 | "type": "INT",
368 | "link": null,
369 | "widget": {
370 | "name": "height_input"
371 | },
372 | "shape": 7
373 | }
374 | ],
375 | "outputs": [
376 | {
377 | "name": "IMAGE",
378 | "type": "IMAGE",
379 | "links": [
380 | 89
381 | ],
382 | "slot_index": 0
383 | },
384 | {
385 | "name": "width",
386 | "type": "INT",
387 | "links": null
388 | },
389 | {
390 | "name": "height",
391 | "type": "INT",
392 | "links": null
393 | }
394 | ],
395 | "properties": {
396 | "Node name for S&R": "ImageResizeKJ"
397 | },
398 | "widgets_values": [
399 | 1600,
400 | 1600,
401 | "lanczos",
402 | true,
403 | 2,
404 | 0,
405 | 0,
406 | "disabled"
407 | ]
408 | },
409 | {
410 | "id": 80,
411 | "type": "LoadImage",
412 | "pos": [
413 | -498.6588134765625,
414 | 202.10902404785156
415 | ],
416 | "size": [
417 | 315,
418 | 314
419 | ],
420 | "flags": {},
421 | "order": 4,
422 | "mode": 0,
423 | "inputs": [],
424 | "outputs": [
425 | {
426 | "name": "IMAGE",
427 | "type": "IMAGE",
428 | "links": [
429 | 207
430 | ]
431 | },
432 | {
433 | "name": "MASK",
434 | "type": "MASK",
435 | "links": null
436 | }
437 | ],
438 | "properties": {
439 | "Node name for S&R": "LoadImage"
440 | },
441 | "widgets_values": [
442 | "candy@2x.png",
443 | "image"
444 | ]
445 | },
446 | {
447 | "id": 81,
448 | "type": "ShowText|pysssss",
449 | "pos": [
450 | 864.541259765625,
451 | 733.3090209960938
452 | ],
453 | "size": [
454 | 315,
455 | 58
456 | ],
457 | "flags": {},
458 | "order": 11,
459 | "mode": 0,
460 | "inputs": [
461 | {
462 | "name": "text",
463 | "type": "STRING",
464 | "link": 208,
465 | "widget": {
466 | "name": "text"
467 | }
468 | }
469 | ],
470 | "outputs": [
471 | {
472 | "name": "STRING",
473 | "type": "STRING",
474 | "links": null,
475 | "shape": 6
476 | }
477 | ],
478 | "properties": {
479 | "Node name for S&R": "ShowText|pysssss"
480 | },
481 | "widgets_values": [
482 | ""
483 | ]
484 | },
485 | {
486 | "id": 82,
487 | "type": "ShowText|pysssss",
488 | "pos": [
489 | 866.1412353515625,
490 | 292.5090026855469
491 | ],
492 | "size": [
493 | 315,
494 | 58
495 | ],
496 | "flags": {},
497 | "order": 10,
498 | "mode": 0,
499 | "inputs": [
500 | {
501 | "name": "text",
502 | "type": "STRING",
503 | "link": 209,
504 | "widget": {
505 | "name": "text"
506 | }
507 | }
508 | ],
509 | "outputs": [
510 | {
511 | "name": "STRING",
512 | "type": "STRING",
513 | "links": null,
514 | "shape": 6
515 | }
516 | ],
517 | "properties": {
518 | "Node name for S&R": "ShowText|pysssss"
519 | },
520 | "widgets_values": [
521 | ""
522 | ]
523 | },
524 | {
525 | "id": 79,
526 | "type": "IF_MemoAvatar",
527 | "pos": [
528 | 415.24700927734375,
529 | 423.51800537109375
530 | ],
531 | "size": [
532 | 400,
533 | 390
534 | ],
535 | "flags": {},
536 | "order": 8,
537 | "mode": 0,
538 | "inputs": [
539 | {
540 | "name": "image",
541 | "type": "IMAGE",
542 | "link": 198
543 | },
544 | {
545 | "name": "audio",
546 | "type": "AUDIO",
547 | "link": 199
548 | },
549 | {
550 | "name": "reference_net",
551 | "type": "MODEL",
552 | "link": 196
553 | },
554 | {
555 | "name": "diffusion_net",
556 | "type": "MODEL",
557 | "link": 197
558 | },
559 | {
560 | "name": "vae",
561 | "type": "VAE",
562 | "link": 200
563 | },
564 | {
565 | "name": "image_proj",
566 | "type": "IMAGE_PROJ",
567 | "link": 201
568 | },
569 | {
570 | "name": "audio_proj",
571 | "type": "AUDIO_PROJ",
572 | "link": 202
573 | },
574 | {
575 | "name": "emotion_classifier",
576 | "type": "EMOTION_CLASSIFIER",
577 | "link": 203
578 | }
579 | ],
580 | "outputs": [
581 | {
582 | "name": "video_path",
583 | "type": "STRING",
584 | "links": [
585 | 206,
586 | 209
587 | ],
588 | "slot_index": 0
589 | },
590 | {
591 | "name": "status",
592 | "type": "STRING",
593 | "links": [
594 | 208
595 | ],
596 | "slot_index": 1
597 | }
598 | ],
599 | "properties": {
600 | "Node name for S&R": "IF_MemoAvatar"
601 | },
602 | "widgets_values": [
603 | 384,
604 | 16,
605 | 30,
606 | 20,
607 | 3.5,
608 | 2047,
609 | "randomize",
610 | "memo_video",
611 | {
612 | "serialize": false,
613 | "size": [
614 | 256,
615 | 256
616 | ]
617 | }
618 | ]
619 | }
620 | ],
621 | "links": [
622 | [
623 | 89,
624 | 54,
625 | 0,
626 | 57,
627 | 0,
628 | "IMAGE"
629 | ],
630 | [
631 | 90,
632 | 57,
633 | 0,
634 | 56,
635 | 0,
636 | "IMAGE"
637 | ],
638 | [
639 | 196,
640 | 74,
641 | 0,
642 | 79,
643 | 2,
644 | "MODEL"
645 | ],
646 | [
647 | 197,
648 | 74,
649 | 1,
650 | 79,
651 | 3,
652 | "MODEL"
653 | ],
654 | [
655 | 198,
656 | 57,
657 | 0,
658 | 79,
659 | 0,
660 | "IMAGE"
661 | ],
662 | [
663 | 199,
664 | 60,
665 | 0,
666 | 79,
667 | 1,
668 | "AUDIO"
669 | ],
670 | [
671 | 200,
672 | 74,
673 | 2,
674 | 79,
675 | 4,
676 | "VAE"
677 | ],
678 | [
679 | 201,
680 | 74,
681 | 3,
682 | 79,
683 | 5,
684 | "IMAGE_PROJ"
685 | ],
686 | [
687 | 202,
688 | 74,
689 | 4,
690 | 79,
691 | 6,
692 | "AUDIO_PROJ"
693 | ],
694 | [
695 | 203,
696 | 74,
697 | 5,
698 | 79,
699 | 7,
700 | "EMOTION_CLASSIFIER"
701 | ],
702 | [
703 | 206,
704 | 79,
705 | 0,
706 | 68,
707 | 2,
708 | "STRING"
709 | ],
710 | [
711 | 207,
712 | 80,
713 | 0,
714 | 54,
715 | 0,
716 | "IMAGE"
717 | ],
718 | [
719 | 208,
720 | 79,
721 | 1,
722 | 81,
723 | 0,
724 | "STRING"
725 | ],
726 | [
727 | 209,
728 | 79,
729 | 0,
730 | 82,
731 | 0,
732 | "STRING"
733 | ]
734 | ],
735 | "groups": [
736 | {
737 | "id": 1,
738 | "title": "AVATAR",
739 | "bounding": [
740 | 65.28616333007812,
741 | 85.40703582763672,
742 | 1330.9530029296875,
743 | 813.2778930664062
744 | ],
745 | "color": "#444",
746 | "font_size": 24,
747 | "flags": {}
748 | },
749 | {
750 | "id": 2,
751 | "title": "MEMO",
752 | "bounding": [
753 | -534.0703735351562,
754 | 84.20343780517578,
755 | 589.9998168945312,
756 | 879.60009765625
757 | ],
758 | "color": "#444",
759 | "font_size": 24,
760 | "flags": {}
761 | }
762 | ],
763 | "config": {},
764 | "extra": {
765 | "ds": {
766 | "scale": 1,
767 | "offset": [
768 | 565.0588152385043,
769 | -17.309013215858414
770 | ]
771 | },
772 | "ue_links": []
773 | },
774 | "version": 0.4
775 | }
--------------------------------------------------------------------------------