├── .gitignore ├── .vscode └── launch.json ├── EmoDataset.md ├── EmoDataset.py ├── README.md ├── configs ├── inference │ └── stage1-base.yaml └── training │ ├── stage1-base.yaml │ └── stage2-hr.yaml ├── data ├── celebvhq_info.json ├── driving_video.json └── overfit.json ├── diagram.jpeg ├── draw_warps.py ├── inference.py ├── junk ├── -1eKufUP5XQ_4.mp4 ├── -2KGPYEFnsU_11.mp4 ├── -2KGPYEFnsU_8.mp4 ├── M2Ohb0FAaJU_1.mp4 └── download.torrent ├── metrics.py ├── model.py ├── mysixdrepnet.py ├── output_images ├── driving_frame_0.png ├── driving_frame_1.png ├── driving_frame_star_0.png ├── driving_frame_star_1.png ├── foreground_mask_0.png ├── masked_predicted_image_0.png ├── output_frame_0.png ├── output_frame_1.png ├── source_frame_0.png ├── source_frame_1.png ├── source_frame_star_0.png ├── source_frame_star_1.png └── warped_driving_frame_0.png ├── reference ├── CVPR2022-DaGAN-warpgenerator.txt ├── CosFace.txt ├── CycleGAN.py ├── DPE.txt ├── EffectiveDeepNetworkHeadPose.tx ├── G2d.png ├── cycleGAN.txt ├── flowfields.png ├── google_scholar_profile_results_data.json ├── highres.txt ├── megaportait-samsung.txt ├── megaportrait-network.png ├── megaportrait-student.png ├── meta-emoportraits.txt ├── oneshotfreeview.txt ├── references.txt ├── resnetblocks.png ├── rome.txt ├── rome_all.py ├── talkingguasian.txt ├── test.py ├── warpfield.png ├── warpgenerators │ ├── bilayer-model.py │ ├── headGAN.py │ ├── headGAN.txt │ ├── latent-pose-reenactment.py │ ├── latent-pose-reenactment.txt │ └── warpgenerators.txt ├── x3.png └── x4.png ├── requirements.txt ├── resnet.py ├── resnet50.py ├── rome_losses.py ├── test.py ├── train.py ├── train_highres.py ├── train_student.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | models/gaze_model_pytorch_vgg16_prl_mpii_allsubjects1.model 4 | *.zip 5 | *.pt 6 | *.dat 7 | *.pth 8 | # output_images/*.* 9 | *.png 10 | junk/-2KGPYEFnsU_11_nobg.mp4 11 | *.png 12 | *.png 13 | *.npz 14 | *.0 15 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "train.py", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /EmoDataset.md: -------------------------------------------------------------------------------- 1 | ## EMODataset Class Summary 2 | 3 | ### Overview 4 | The `EMODataset` class is a PyTorch dataset for processing and augmenting video frames, with functionalities to remove backgrounds, warp and crop faces, and save/load processed frames efficiently. The class is designed to handle large video datasets and includes methods to streamline the preprocessing pipeline. 5 | 6 | ### Dependencies 7 | The class relies on the following libraries: 8 | - `moviepy.editor`: Video editing and processing. 9 | - `PIL.Image`: Image processing. 10 | - `torch`: PyTorch for tensor operations and model support. 11 | - `torchvision.transforms`: Image transformations. 12 | - `decord`: Efficient video reading. 13 | - `rembg`: Background removal. 
14 | - `face_recognition`: Face detection. 15 | - `skimage.transform`: Image warping. 16 | - `cv2`: Video writing with OpenCV. 17 | - `numpy`: Array operations. 18 | - `io`, `os`, `json`, `Path`, `subprocess`, `tqdm`: Standard libraries for file handling, I/O operations, and progress visualization. 19 | 20 | ### Initialization 21 | The `__init__` method sets up the dataset with various parameters: 22 | - `use_gpu`, `sample_rate`, `n_sample_frames`, `width`, `height`, `img_scale`, `img_ratio`, `video_dir`, `drop_ratio`, `json_file`, `stage`, `transform`, `remove_background`, `use_greenscreen`, `apply_crop_warping` 23 | - Loads video metadata from the provided JSON file. 24 | - Initializes decord for video reading with PyTorch tensor output. 25 | 26 | ### Methods 27 | 28 | #### `__len__` 29 | Returns the length of the dataset, determined by the number of video IDs. 30 | 31 | #### `warp_and_crop_face` 32 | Processes an image tensor to detect, warp, and crop the face region: 33 | - Converts tensor to PIL image. 34 | - Removes background. 35 | - Detects face locations. 36 | - Crops the face region. 37 | - Optionally applies thin-plate-spline warping. 38 | - Converts the processed image back to a tensor and returns it. 39 | 40 | #### `load_and_process_video` 41 | Loads and processes video frames: 42 | - Checks if processed tensor file exists; if so, loads tensors. 43 | - If not, processes video frames, applies augmentation, and saves frames as PNG images and tensors. 44 | - Saves processed tensors as compressed numpy arrays for efficient loading. 45 | 46 | #### `augmentation` 47 | Applies transformations and optional background removal to the provided images: 48 | - Supports both single images and lists of images. 49 | - Returns transformed tensors. 50 | 51 | #### `remove_bg` 52 | Removes the background from the provided image using `rembg`: 53 | - Optionally applies a green screen background. 54 | - Converts image to RGB format and returns it. 55 | 56 | #### `save_video` 57 | Saves a list of frames as a video file: 58 | - Uses OpenCV to write frames to a video file. 59 | 60 | #### `process_video` 61 | Processes all frames of a video: 62 | - Uses the `process_video_frames` method to process frames. 63 | 64 | #### `process_video_frames` 65 | Processes frames of a video using decord: 66 | - Reads frames using decord and applies augmentation. 67 | - Returns processed frames. 68 | 69 | #### `__getitem__` 70 | Returns a sample from the dataset: 71 | - Loads and processes source and driving videos. 72 | - Returns a dictionary containing video IDs and frames. 73 | 74 | ### Usage 75 | To use the `EMODataset` class: 76 | 1. Initialize the dataset with appropriate parameters. 77 | 2. Use PyTorch DataLoader to iterate over the dataset and retrieve samples. 78 | 3. Process the frames as needed for training or inference in a machine learning model. 
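For step 2, a minimal usage sketch (hypothetical, not taken from the repo) of wrapping the dataset in a PyTorch `DataLoader`. Since each sample carries full per-video lists of frame tensors whose lengths differ between videos, `batch_size=1` (or a custom `collate_fn`) is the safer choice; `dataset` is an `EMODataset` instance constructed as in the Example section below.

```python
from torch.utils.data import DataLoader

# batch_size=1 avoids collation issues from variable-length frame lists.
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

for batch in loader:
    print(batch["video_id"], len(batch["source_frames"]))
```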
79 | 80 | ### Example 81 | ```python 82 | from torchvision import transforms 83 | 84 | transform = transforms.Compose([ 85 | transforms.Resize((512, 512)), 86 | transforms.ToTensor(), 87 | ]) 88 | 89 | dataset = EMODataset( 90 | use_gpu=False, 91 | sample_rate=5, 92 | n_sample_frames=16, 93 | width=512, 94 | height=512, 95 | img_scale=(0.9, 1.0), 96 | video_dir="path/to/videos", 97 | json_file="path/to/metadata.json", 98 | transform=transform, 99 | remove_background=True, 100 | use_greenscreen=False, 101 | apply_crop_warping=True 102 | ) 103 | 104 | for sample in dataset: 105 | print(sample) 106 | ``` 107 | 108 | This class provides a comprehensive pipeline for processing video data, making it suitable for tasks such as training deep learning models on video datasets. -------------------------------------------------------------------------------- /EmoDataset.py: -------------------------------------------------------------------------------- 1 | from moviepy.editor import VideoFileClip, ImageSequenceClip 2 | from PIL import Image 3 | import torch 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import Dataset 6 | import json 7 | import os 8 | from typing import List, Tuple, Dict, Any 9 | from decord import VideoReader, cpu 10 | from rembg import remove 11 | import io 12 | import numpy as np 13 | import decord 14 | import subprocess 15 | from tqdm import tqdm 16 | import cv2 17 | from pathlib import Path 18 | from torchvision.transforms.functional import to_pil_image, to_tensor 19 | import random 20 | # face warp 21 | from skimage.transform import PiecewiseAffineTransform, warp 22 | import face_recognition 23 | 24 | class EMODataset(Dataset): 25 | def __init__(self, use_gpu: False, sample_rate: int, n_sample_frames: int, width: int, height: int, img_scale: Tuple[float, float], img_ratio: Tuple[float, float] = (0.9, 1.0), video_dir: str = ".", drop_ratio: float = 0.1, json_file: str = "", stage: str = 'stage1', transform: transforms.Compose = None, remove_background=False, use_greenscreen=False, apply_crop_warping=False): 26 | self.sample_rate = sample_rate 27 | self.n_sample_frames = n_sample_frames 28 | self.width = width 29 | self.height = height 30 | self.img_scale = img_scale 31 | self.img_ratio = img_ratio 32 | self.video_dir = video_dir 33 | self.transform = transform 34 | self.stage = stage 35 | self.pixel_transform = transform 36 | self.drop_ratio = drop_ratio 37 | self.remove_background = remove_background 38 | self.use_greenscreen = use_greenscreen 39 | self.apply_crop_warping = apply_crop_warping 40 | with open(json_file, 'r') as f: 41 | self.celebvhq_info = json.load(f) 42 | 43 | self.use_gpu = use_gpu 44 | 45 | decord.bridge.set_bridge('torch') # Optional: This line sets decord to directly output PyTorch tensors. 
46 | self.ctx = cpu() 47 | 48 | self.video_ids = list(self.celebvhq_info['clips'].keys()) 49 | 50 | random_video_id = random.choice(self.video_ids) 51 | driving = os.path.join(self.video_dir, f"{random_video_id}.mp4") 52 | print("driving:",driving) 53 | 54 | self.driving_vid_pil_image_list = self.load_and_process_video(driving) 55 | self.video_ids_star = list(self.celebvhq_info['clips'].keys()) 56 | 57 | random_video_id = random.choice(self.video_ids_star) 58 | driving_star = os.path.join(self.video_dir, f"{random_video_id}.mp4") 59 | print("driving_star:",driving_star) 60 | 61 | self.driving_vid_pil_image_list_star = self.load_and_process_video(driving_star) 62 | 63 | # TODO - make this more dynamic 64 | # driving = os.path.join(self.video_dir, "-2KGPYEFnsU_11.mp4") 65 | # self.driving_vid_pil_image_list = self.load_and_process_video(driving) 66 | # self.video_ids = ["M2Ohb0FAaJU_1"] # list(self.celebvhq_info['clips'].keys()) 67 | # self.video_ids_star = ["-1eKufUP5XQ_4"] # list(self.celebvhq_info['clips'].keys()) 68 | # driving_star = os.path.join(self.video_dir, "-2KGPYEFnsU_8.mp4") 69 | # self.driving_vid_pil_image_list_star = self.load_and_process_video(driving_star) 70 | 71 | def __len__(self) -> int: 72 | return len(self.video_ids) 73 | 74 | def warp_and_crop_face(self, image_tensor, video_name, frame_idx, transform=None, output_dir="output_images", warp_strength=0.01, apply_warp=False): 75 | # Ensure the output directory exists 76 | os.makedirs(output_dir, exist_ok=True) 77 | 78 | # Construct the file path 79 | output_path = os.path.join(output_dir, f"{video_name}_frame_{frame_idx}.png") 80 | 81 | # Check if the file already exists 82 | if os.path.exists(output_path): 83 | # Load and return the existing image as a tensor 84 | existing_image = Image.open(output_path).convert("RGBA") 85 | return to_tensor(existing_image) 86 | 87 | # Check if the input tensor has a batch dimension and handle it 88 | if image_tensor.ndim == 4: 89 | # Assuming batch size is the first dimension, process one image at a time 90 | image_tensor = image_tensor.squeeze(0) 91 | 92 | # Convert the single image tensor to a PIL Image 93 | image = to_pil_image(image_tensor) 94 | 95 | # Remove the background from the image 96 | img_byte_arr = io.BytesIO() 97 | image.save(img_byte_arr, format='PNG') 98 | img_byte_arr = img_byte_arr.getvalue() 99 | bg_removed_bytes = remove(img_byte_arr) 100 | bg_removed_image = Image.open(io.BytesIO(bg_removed_bytes)).convert("RGBA") 101 | 102 | # Convert the image to RGB format to make it compatible with face_recognition 103 | bg_removed_image_rgb = bg_removed_image.convert("RGB") 104 | 105 | # Detect the face in the background-removed RGB image using the numpy array 106 | face_locations = face_recognition.face_locations(np.array(bg_removed_image_rgb)) 107 | 108 | if len(face_locations) > 0: 109 | top, right, bottom, left = face_locations[0] 110 | 111 | # Automatically choose sweet spot to crop. 
112 | # https://github.com/tencent-ailab/V-Express/blob/main/assets/crop_example.jpeg 113 | 114 | face_width = right - left 115 | face_height = bottom - top 116 | 117 | # Calculate the padding amount based on face size and output dimensions 118 | pad_width = int(face_width * 0.5) 119 | pad_height = int(face_height * 0.5) 120 | 121 | # Expand the cropping coordinates with the calculated padding 122 | left_with_pad = max(0, left - pad_width) 123 | top_with_pad = max(0, top - pad_height) 124 | right_with_pad = min(bg_removed_image.width, right + pad_width) 125 | bottom_with_pad = min(bg_removed_image.height, bottom + pad_height) 126 | 127 | # Crop the face region from the image with padding 128 | face_image_with_pad = bg_removed_image.crop((left_with_pad, top_with_pad, right_with_pad, bottom_with_pad)) 129 | 130 | # Crop the face region from the image without padding 131 | face_image_no_pad = bg_removed_image.crop((left, top, right, bottom)) 132 | 133 | if apply_warp: 134 | # Convert the face image to a numpy array 135 | face_array_with_pad = np.array(face_image_with_pad) 136 | face_array_no_pad = np.array(face_image_no_pad) 137 | 138 | # Generate random control points for thin-plate-spline warping 139 | rows_with_pad, cols_with_pad = face_array_with_pad.shape[:2] 140 | rows_no_pad, cols_no_pad = face_array_no_pad.shape[:2] 141 | src_points_with_pad = np.array([[0, 0], [cols_with_pad-1, 0], [0, rows_with_pad-1], [cols_with_pad-1, rows_with_pad-1]]) 142 | src_points_no_pad = np.array([[0, 0], [cols_no_pad-1, 0], [0, rows_no_pad-1], [cols_no_pad-1, rows_no_pad-1]]) 143 | dst_points_with_pad = src_points_with_pad + np.random.randn(4, 2) * (rows_with_pad * warp_strength) 144 | dst_points_no_pad = src_points_no_pad + np.random.randn(4, 2) * (rows_no_pad * warp_strength) 145 | 146 | # Create a PiecewiseAffineTransform object 147 | tps_with_pad = PiecewiseAffineTransform() 148 | tps_with_pad.estimate(src_points_with_pad, dst_points_with_pad) 149 | tps_no_pad = PiecewiseAffineTransform() 150 | tps_no_pad.estimate(src_points_no_pad, dst_points_no_pad) 151 | 152 | # Apply the thin-plate-spline warping to the face images 153 | warped_face_array_with_pad = warp(face_array_with_pad, tps_with_pad, output_shape=(rows_with_pad, cols_with_pad)) 154 | warped_face_array_no_pad = warp(face_array_no_pad, tps_no_pad, output_shape=(rows_no_pad, cols_no_pad)) 155 | 156 | # Convert the warped face arrays back to PIL images 157 | warped_face_image_with_pad = Image.fromarray((warped_face_array_with_pad * 255).astype(np.uint8)) 158 | warped_face_image_no_pad = Image.fromarray((warped_face_array_no_pad * 255).astype(np.uint8)) 159 | else: 160 | warped_face_image_with_pad = face_image_with_pad 161 | warped_face_image_no_pad = face_image_no_pad 162 | 163 | # Apply the transform if provided 164 | if transform: 165 | warped_face_image_with_pad = warped_face_image_with_pad.convert("RGB") 166 | warped_face_image_no_pad = warped_face_image_no_pad.convert("RGB") 167 | warped_face_tensor_with_pad = transform(warped_face_image_with_pad) 168 | warped_face_tensor_no_pad = transform(warped_face_image_no_pad) 169 | return warped_face_tensor_with_pad, warped_face_tensor_no_pad 170 | 171 | # Convert the warped PIL images back to tensors 172 | warped_face_image_with_pad = warped_face_image_with_pad.convert("RGB") 173 | warped_face_image_no_pad = warped_face_image_no_pad.convert("RGB") 174 | return to_tensor(warped_face_image_with_pad), to_tensor(warped_face_image_no_pad) 175 | 176 | else: 177 | return None, None 178 | 179 | 180 | def 
load_and_process_video(self, video_path: str) -> List[torch.Tensor]: 181 | # Extract video ID from the path 182 | video_id = Path(video_path).stem 183 | output_dir = Path(self.video_dir + "/" + video_id) 184 | output_dir.mkdir(exist_ok=True) 185 | 186 | processed_frames = [] 187 | tensor_frames = [] 188 | 189 | tensor_file_path = output_dir / f"{video_id}_tensors.npz" 190 | 191 | # Check if the tensor file exists 192 | if tensor_file_path.exists(): 193 | print(f"Loading processed tensors from file: {tensor_file_path}") 194 | with np.load(tensor_file_path) as data: 195 | tensor_frames = [torch.tensor(data[key]) for key in data] 196 | else: 197 | if self.apply_crop_warping: 198 | print(f"Warping + Processing and saving video frames to directory: {output_dir}") 199 | else: 200 | print(f"Processing and saving video frames to directory: {output_dir}") 201 | video_reader = VideoReader(video_path, ctx=self.ctx) 202 | for frame_idx in tqdm(range(len(video_reader)), desc="Processing Video Frames"): 203 | frame = Image.fromarray(video_reader[frame_idx].numpy()) 204 | state = torch.get_rng_state() 205 | # here we run the color jitter / random flip 206 | tensor_frame, image_frame = self.augmentation(frame, self.pixel_transform, state) 207 | processed_frames.append(image_frame) 208 | 209 | if self.apply_crop_warping: 210 | transform = transforms.Compose([ 211 | transforms.Resize((512, 512)), # get the cropped image back to this size - TODO support 256 212 | transforms.ToTensor(), 213 | ]) 214 | video_name = Path(video_path).stem 215 | 216 | # vanilla crop 217 | _,sweet_tensor_frame1 = self.warp_and_crop_face(tensor_frame, video_name, frame_idx, transform, apply_warp=False) 218 | # Save frame as PNG image 219 | # img = to_pil_image(tensor_frame1) 220 | # img.save(output_dir / f"{frame_idx:06d}.png") 221 | # tensor_frames.append(tensor_frame1) 222 | 223 | img = to_pil_image(sweet_tensor_frame1) 224 | img.save(output_dir / f"s_{frame_idx:06d}.png") 225 | tensor_frames.append(sweet_tensor_frame1) 226 | 227 | # vanilla crop + warp 228 | _,sweet_tensor_frame2 = self.warp_and_crop_face(tensor_frame, video_name, frame_idx, transform, apply_warp=True) 229 | # Save frame as PNG image 230 | # img = to_pil_image(tensor_frame2) 231 | # img.save(output_dir / f"w_{frame_idx:06d}.png") 232 | # tensor_frames.append(tensor_frame2) 233 | 234 | # Save frame as PNG image 235 | img = to_pil_image(sweet_tensor_frame2) 236 | img.save(output_dir / f"sw_{frame_idx:06d}.png") 237 | tensor_frames.append(sweet_tensor_frame2) 238 | else: 239 | # Save frame as PNG image 240 | image_frame.save(output_dir / f"{frame_idx:06d}.png") 241 | tensor_frames.append(tensor_frame) 242 | 243 | # Convert tensor frames to numpy arrays and save them 244 | np.savez_compressed(tensor_file_path, *[tensor_frame.numpy() for tensor_frame in tensor_frames]) 245 | print(f"Processed tensors saved to file: {tensor_file_path}") 246 | 247 | return tensor_frames 248 | 249 | def augmentation(self, images, transform, state=None): 250 | if state is not None: 251 | torch.set_rng_state(state) 252 | 253 | if isinstance(images, list): 254 | if self.remove_background: 255 | images = [self.remove_bg(img) for img in images] 256 | transformed_images = [transform(img) for img in tqdm(images, desc="Augmenting Images")] 257 | ret_tensor = torch.stack(transformed_images, dim=0) 258 | else: 259 | if self.remove_background: 260 | images = self.remove_bg(images) 261 | ret_tensor = transform(images) 262 | 263 | return ret_tensor, images 264 | 265 | def remove_bg(self, image): 266 
| img_byte_arr = io.BytesIO() 267 | image.save(img_byte_arr, format='PNG') 268 | img_byte_arr = img_byte_arr.getvalue() 269 | bg_removed_bytes = remove(img_byte_arr) 270 | bg_removed_image = Image.open(io.BytesIO(bg_removed_bytes)).convert("RGBA") # Use RGBA to keep transparency 271 | 272 | if self.use_greenscreen: 273 | # Create a green screen background 274 | green_screen = Image.new("RGBA", bg_removed_image.size, (0, 255, 0, 255)) # Green color 275 | 276 | # Composite the image onto the green screen 277 | final_image = Image.alpha_composite(green_screen, bg_removed_image) 278 | else: 279 | final_image = bg_removed_image 280 | 281 | final_image = final_image.convert("RGB") # Convert to RGB format 282 | return final_image 283 | 284 | def save_video(self, frames, output_path, fps=30): 285 | print(f"Saving video with {len(frames)} frames to {output_path}") 286 | 287 | # Define the codec and create VideoWriter object 288 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # You can change 'mp4v' to other codecs if needed 289 | height, width, _ = np.array(frames[0]).shape 290 | out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 291 | 292 | for frame in frames: 293 | frame = np.array(frame) 294 | if frame.shape[2] == 4: 295 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 296 | out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) # Convert to BGR format 297 | 298 | out.release() 299 | print(f"Video saved to {output_path}") 300 | 301 | def process_video(self, video_path): 302 | processed_frames = self.process_video_frames(video_path) 303 | return processed_frames 304 | 305 | def process_video_frames(self, video_path: str) -> List[torch.Tensor]: 306 | video_reader = VideoReader(video_path, ctx=self.ctx) 307 | processed_frames = [] 308 | for frame_idx in tqdm(range(len(video_reader)), desc="Processing Video Frames"): 309 | frame = Image.fromarray(video_reader[frame_idx].numpy()) 310 | state = torch.get_rng_state() 311 | tensor_frame, image_frame = self.augmentation(frame, self.pixel_transform, state) 312 | processed_frames.append(image_frame) 313 | return processed_frames 314 | 315 | def __getitem__(self, index: int) -> Dict[str, Any]: 316 | video_id = self.video_ids[index] 317 | # Use next item in the list for video_id_star, wrap around if at the end 318 | video_id_star = self.video_ids_star[(index + 1) % len(self.video_ids_star)] 319 | vid_pil_image_list = self.load_and_process_video(os.path.join(self.video_dir, f"{video_id}.mp4")) 320 | vid_pil_image_list_star = self.load_and_process_video(os.path.join(self.video_dir, f"{video_id_star}.mp4")) 321 | 322 | sample = { 323 | "video_id": video_id, 324 | "source_frames": vid_pil_image_list, 325 | "driving_frames": self.driving_vid_pil_image_list, 326 | "video_id_star": video_id_star, 327 | "source_frames_star": vid_pil_image_list_star, 328 | "driving_frames_star": self.driving_vid_pil_image_list_star, 329 | } 330 | return sample 331 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## IMPORTANT 2 | My VASA hack project https://github.com/johndpope/vasa-1-hack has running /training code stage 1 (megaportraits) - with hot fixes 3 | [https://github.com/johndpope/VASA-1-hack/blob/main/train_stage_1.py 4 | ](https://github.com/johndpope/VASA-1-hack/commit/430947d9707777d2ed9d38183e523d31c13054eb) 5 | 6 | 7 | 8 | 9 | 10 | # MegaPortrait - SamsungLabs AI - Russia 11 | Implementation of Megaportrait using Claude Opus 12 | 13 | 
14 | All models / code is in model.py 15 | 16 | 17 | ![Image](diagram.jpeg) 18 | 19 | 20 | memory debug 21 | ```shell 22 | mprof run train.py 23 | ``` 24 | or just 25 | ```shell 26 | python train.py 27 | ``` 28 | 29 | 30 | 31 | ### UPDATES 32 | 33 | - Save / restore checkpoint) specify in config ./configs/training/stage10base.yaml to restore checkpoint 34 | - auto crop video frames to sweet spot 35 | - tensorboard losses 36 | - LPIPS 37 | - additional imagepyramide from one shot view code for loss - (this broke things..) 38 | 39 | 40 | 41 | 42 | ### EmoDataset 43 | 44 | [warp / crop / spline / remove background / transforms](EmoDataset.md) 45 | 46 | ## Training Data (☢️ dont need this yet.) 47 | 48 | - **Total Videos:** 35,000 facial videos 49 | - **Total Size:** 40GB 50 | 51 | 52 | ### Training Strategy 53 | for now - to simplify problem - use the 4 videos in junk folder. 54 | once models are validated - can point the video_dir to above torrent 55 | ```yaml 56 | # video_dir: '/Downloads/CelebV-HQ/celebvhq/35666' 57 | video_dir: './junk' 58 | ``` 59 | the preprocessing is taking 1-2 mins for each video - I add some saving to npz format for faster reloading. 60 | 61 | 62 | ### Torrent Download 63 | 64 | You can download the dataset via the provided magnet link or by visiting [Academic Torrents](https://academictorrents.com/details/843b5adb0358124d388c4e9836654c246b988ff4). 65 | 66 | ```plaintext 67 | magnet:?xt=urn:btih:843b5adb0358124d388c4e9836654c246b988ff4&dn=CelebV-HQ&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=https%3A%2F%2Fipv6.academictorrents.com%2Fannounce.php 68 | ``` 69 | 70 | 71 | 72 | ### Implemented Functionality / Descriptions 73 | 74 | #### Base Model (`Gbase`) 75 | - **Description**: Responsible for creating the foundational neural head avatar at a medium resolution of \(512 x 512\). Uses volumetric features to encode appearance and latent descriptors to encode motion. 76 | - **Components**: 77 | - **Appearance Encoder (`Eapp`)**: Encodes the appearance of the source frame into volumetric features and a global descriptor. 78 | ```python 79 | class Eapp(nn.Module): 80 | # Architecture details omitted for brevity 81 | ``` 82 | - **Motion Encoder (`Emtn`)**: Encodes the motion from both source and driving images into head rotations, translations, and latent expression descriptors. 83 | ```python 84 | class Emtn(nn.Module): 85 | # Architecture details omitted for brevity 86 | ``` 87 | - **Warping Generators (`Wsrc_to_can` and `Wcan_to_drv`)**: Removes motion from the source and imposes driver motion onto canonical features. 88 | ```python 89 | class WarpGenerator(nn.Module): 90 | # Architecture details omitted for brevity 91 | ``` 92 | - **3D Convolutional Network (`G3D`)**: Processes canonical volumetric features. 93 | ```python 94 | class G3D(nn.Module): 95 | # Architecture details omitted for brevity 96 | ``` 97 | - **2D Convolutional Network (`G2D`)**: Projects 3D features into 2D and generates the output image. 98 | ```python 99 | class G2D(nn.Module): 100 | # Architecture details omitted for brevity 101 | ``` 102 | 103 | #### High-Resolution Model (`GHR`) 104 | - **Description**: Enhances the resolution of the base model output from \(512 \times 512\) to \(1024 \times 1024\) using a high-resolution dataset of photographs. 105 | - **Components**: 106 | - **Encoder**: Takes the base model output and produces a 3D feature tensor. 
107 | ```python 108 | class EncoderHR(nn.Module): 109 | # Architecture details omitted for brevity 110 | ``` 111 | - **Decoder**: Converts the 3D feature tensor to a high-resolution image. 112 | ```python 113 | class DecoderHR(nn.Module): 114 | # Architecture details omitted for brevity 115 | ``` 116 | 117 | #### Student Model (`Student`) 118 | - **Description**: A distilled version of the high-resolution model for real-time applications. Trained to mimic the full model’s predictions but runs faster and is limited to a predefined number of avatars. 119 | - **Components**: 120 | - **ResNet18 Encoder**: Encodes the input image. 121 | ```python 122 | class ResNet18(nn.Module): 123 | # Architecture details omitted for brevity 124 | ``` 125 | - **Generator with SPADE Normalization Layers**: Generates the final output image. Each SPADE block uses tensors specific to an avatar. 126 | ```python 127 | class SPADEGenerator(nn.Module): 128 | # Architecture details omitted for brevity 129 | ``` 130 | 131 | #### Gaze and Blink Loss Model 132 | - **Description**: Computes the gaze and blink loss using a pretrained face mesh from MediaPipe and a custom network. The gaze loss uses MAE and MSE, while the blink loss uses binary cross-entropy. 133 | - **Components**: 134 | - **Backbone (VGG16)**: Extracts features from the eye images. 135 | ```python 136 | class VGG16Backbone(nn.Module): 137 | # Architecture details omitted for brevity 138 | ``` 139 | - **Keypoint Network**: Processes 2D keypoints. 140 | ```python 141 | class KeypointNet(nn.Module): 142 | # Architecture details omitted for brevity 143 | ``` 144 | - **Gaze Head**: Predicts gaze direction. 145 | ```python 146 | class GazeHead(nn.Module): 147 | # Architecture details omitted for brevity 148 | ``` 149 | - **Blink Head**: Predicts blink probability. 150 | ```python 151 | class BlinkHead(nn.Module): 152 | # Architecture details omitted for brevity 153 | ``` 154 | 155 | #### Training Functions 156 | - **`train_base(cfg, Gbase, Dbase, dataloader)`**: Trains the base model using perceptual, adversarial, and cycle consistency losses. 157 | ```python 158 | def train_base(cfg, Gbase, Dbase, dataloader): 159 | # Training code omitted for brevity 160 | ``` 161 | - **`train_hr(cfg, GHR, Dhr, dataloader)`**: Trains the high-resolution model using super-resolution objectives and adversarial losses. 162 | ```python 163 | def train_hr(cfg, GHR, Dhr, dataloader): 164 | # Training code omitted for brevity 165 | ``` 166 | - **`train_student(cfg, Student, GHR, dataloader)`**: Distills the high-resolution model into a student model for faster inference. 167 | ```python 168 | def train_student(cfg, Student, GHR, dataloader): 169 | # Training code omitted for brevity 170 | ``` 171 | 172 | #### Training Pipeline 173 | - **Data Augmentation**: Applies random horizontal flips, color jitter, and other augmentations to the input images. 174 | - **Optimizers**: Uses AdamW optimizer with cosine learning rate scheduling for both base and high-resolution models. 175 | - **Losses**: 176 | - **Perceptual Loss**: Matches the content and facial appearance between predicted and ground-truth images. 177 | - **Adversarial Loss**: Ensures the realism of predicted images using a multi-scale patch discriminator. 178 | - **Cycle Consistency Loss**: Prevents appearance leakage through the motion descriptor. 
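As a concrete illustration of how these weighted terms can be combined on the generator side, here is a minimal sketch using the `w_per` / `w_adv` / `w_cos` weights defined in `configs/training/stage1-base.yaml`. The specific loss modules chosen below (LPIPS for the perceptual term, a hinge form for the adversarial term, cosine embedding on latent motion descriptors for cycle consistency) are illustrative assumptions, not the exact code in `train.py` / `model.py`, which also carries further terms such as feature matching (`w_fm`).

```python
import torch
import torch.nn as nn
import lpips  # perceptual metric already used elsewhere in this repo (metrics.py)

class GeneratorLossSketch(nn.Module):
    """Illustrative combination of the three generator losses described above."""
    def __init__(self, w_per=20.0, w_adv=1.0, w_cos=2.0):
        super().__init__()
        self.w_per, self.w_adv, self.w_cos = w_per, w_adv, w_cos
        self.perceptual = lpips.LPIPS(net='vgg')   # content / facial appearance match
        self.cosine = nn.CosineEmbeddingLoss()     # cycle consistency on motion descriptors

    def forward(self, pred, target, fake_logits, z_pred, z_drv):
        # Perceptual: predicted frame vs. ground-truth frame (inputs in [-1, 1]).
        l_per = self.perceptual(pred, target).mean()
        # Adversarial (hinge, generator side): raise the discriminator's score on fakes.
        l_adv = -fake_logits.mean()
        # Cycle consistency: latent expression descriptors of the same motion should align.
        ones = torch.ones(z_pred.size(0), device=z_pred.device)
        l_cos = self.cosine(z_pred, z_drv, ones)
        return self.w_per * l_per + self.w_adv * l_adv + self.w_cos * l_cos
```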
179 | 180 | #### Main Function 181 | - **Description**: Sets up the dataset and data loaders, initializes the models, and calls the training functions for base, high-resolution, and student models. 182 | - **Implementation**: 183 | ```python 184 | def main(cfg: OmegaConf) -> None: 185 | use_cuda = torch.cuda.is_available() 186 | device = torch.device("cuda" if use_cuda else "cpu") 187 | 188 | transform = transforms.Compose([ 189 | transforms.ToTensor(), 190 | transforms.Normalize([0.5], [0.5]), 191 | transforms.RandomHorizontalFlip(), 192 | transforms.ColorJitter() 193 | ]) 194 | 195 | dataset = EMODataset( 196 | use_gpu=use_cuda, 197 | width=cfg.data.train_width, 198 | height=cfg.data.train_height, 199 | n_sample_frames=cfg.training.n_sample_frames, 200 | sample_rate=cfg.training.sample_rate, 201 | img_scale=(1.0, 1.0), 202 | video_dir=cfg.training.video_dir, 203 | json_file=cfg.training.json_file, 204 | transform=transform 205 | ) 206 | 207 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4) 208 | 209 | Gbase = model.Gbase() 210 | Dbase = model.Discriminator() 211 | train_base(cfg, Gbase, Dbase, dataloader) 212 | 213 | GHR = model.GHR() 214 | GHR.Gbase.load_state_dict(Gbase.state_dict()) 215 | Dhr = model.Discriminator() 216 | train_hr(cfg, GHR, Dhr, dataloader) 217 | 218 | Student = model.Student(num_avatars=100) 219 | train_student(cfg, Student, GHR, dataloader) 220 | 221 | torch.save(Gbase.state_dict(), 'Gbase.pth') 222 | torch.save(GHR.state_dict(), 'GHR.pth') 223 | torch.save(Student.state_dict(), 'Student.pth') 224 | 225 | if __name__ == "__main__": 226 | config = OmegaConf.load("./configs/training/stage1-base.yaml") 227 | main(config) 228 | ``` 229 | 230 | 231 | rome/losses - cherry picked from 232 | https://github.com/SamsungLabs/rome 233 | 234 | 235 | 236 | 237 | 238 | wget 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 239 | extract to state_dicts 240 | 241 | 242 | # RT-GENE (Real-Time Gaze Estimation) - couldn't get working 243 | ```python 244 | git clone https://github.com/Tobias-Fischer/rt_gene.git 245 | cd rt_gene/rt_gene 246 | pip install . 
247 | ``` 248 | -------------------------------------------------------------------------------- /configs/inference/stage1-base.yaml: -------------------------------------------------------------------------------- 1 | inference: 2 | checkpoint_path: './checkpoint_epoch100.pth' 3 | source_image: "path/to/source_image.png" 4 | driving_image: "path/to/driving_image.png" 5 | output_image: "output_base.jpg" -------------------------------------------------------------------------------- /configs/training/stage1-base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_width: 512 3 | train_height: 512 4 | sample_rate: 25 5 | n_sample_frames: 1 6 | n_motion_frames: 2 7 | training: 8 | frame_offset: 20 9 | checkpoint_path: './checkpoint_epoch100.pth' 10 | save_interval: 50 11 | log_interval: 100 12 | lambda_perceptual: 1.0 13 | lambda_adversarial: 1.0 14 | lambda_cosine: 1.0 15 | lambda_keypoints: 1.0 16 | lambda_gaze: 1.0 17 | lambda_supervised: 1.0 18 | lambda_unsupervised: 1.0 19 | batch_size: 24 20 | num_workers: 0 21 | lr: 1.0e-5 22 | base_epochs: 100 23 | hr_epochs: 50 24 | student_epochs: 100 25 | use_gpu_video_tensor: True 26 | prev_frames: 2 # Add this line to specify the number of previous frames to consider 27 | # video_dir: '/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666' # point to your junk folder or 40gb - https://academictorrents.com/details/843b5adb0358124d388c4e9836654c246b988ff4 28 | video_dir: './junk' 29 | sample_rate: 25 30 | n_sample_frames: 100 31 | json_file: './data/overfit.json' 32 | 33 | 34 | w_per: 20 # perceptual loss 35 | w_adv: 1 # adversarial loss 36 | w_fm: 40 # feature matching loss 37 | w_cos: 2 # cycle consistency loss 38 | w_pairwise: 1 39 | w_identity: 1 40 | w_cyc: 1 41 | 42 | -------------------------------------------------------------------------------- /configs/training/stage2-hr.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_width: 512 3 | train_height: 512 4 | sample_rate: 25 5 | n_sample_frames: 1 6 | n_motion_frames: 2 7 | training: 8 | 9 | batch_size: 2 10 | num_workers: 0 11 | lr: 1.0e-5 12 | base_epochs: 100 13 | hr_epochs: 50 14 | student_epochs: 100 15 | use_gpu_video_tensor: True 16 | prev_frames: 2 # Add this line to specify the number of previous frames to consider 17 | video_dir: '/media/oem/12TB/Downloads/CelebV-HQ/celebvhq/35666' 18 | sample_rate: 25 19 | n_sample_frames: 100 20 | json_file: './data/overfit.json' -------------------------------------------------------------------------------- /data/driving_video.json: -------------------------------------------------------------------------------- 1 | {"meta_info": {"appearance_mapping": ["blurry", "male", "young", "chubby", "pale_skin", "rosy_cheeks", "oval_face", "receding_hairline", "bald", "bangs", "black_hair", "blonde_hair", "gray_hair", "brown_hair", "straight_hair", "wavy_hair", "long_hair", "arched_eyebrows", "bushy_eyebrows", "bags_under_eyes", "eyeglasses", "sunglasses", "narrow_eyes", "big_nose", "pointy_nose", "high_cheekbones", "big_lips", "double_chin", "no_beard", "5_o_clock_shadow", "goatee", "mustache", "sideburns", "heavy_makeup", "wearing_earrings", "wearing_hat", "wearing_lipstick", "wearing_necklace", "wearing_necktie", "wearing_mask"], "action_mapping": ["blow", "chew", "close_eyes", "cough", "cry", "drink", "eat", "frown", "gaze", "glare", "head_wagging", "kiss", "laugh", "listen_to_music", "look_around", "make_a_face", "nod", "play_instrument", 
"read", "shake_head", "shout", "sigh", "sing", "sleep", "smile", "smoke", "sneer", "sneeze", "sniff", "talk", "turn", "weep", "whisper", "wink", "yawn"]}, "clips": {"-2KGPYEFnsU_8": {"ytb_id": "-2KGPYEFnsU", "duration": {"start_sec": 102.6, "end_sec": 106.52}, "bbox": {"top": 0.0991, "bottom": 0.612, "left": 0.1234, "right": 0.412}, "attributes": {"appearance": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "emotion": {"sep_flag": false, "labels": "neutral"}}, "version": "v0.1"}}} 2 | 3 | 4 | -------------------------------------------------------------------------------- /data/overfit.json: -------------------------------------------------------------------------------- 1 | {"meta_info": {"appearance_mapping": ["blurry", "male", "young", "chubby", "pale_skin", "rosy_cheeks", "oval_face", "receding_hairline", "bald", "bangs", "black_hair", "blonde_hair", "gray_hair", "brown_hair", "straight_hair", "wavy_hair", "long_hair", "arched_eyebrows", "bushy_eyebrows", "bags_under_eyes", "eyeglasses", "sunglasses", "narrow_eyes", "big_nose", "pointy_nose", "high_cheekbones", "big_lips", "double_chin", "no_beard", "5_o_clock_shadow", "goatee", "mustache", "sideburns", "heavy_makeup", "wearing_earrings", "wearing_hat", "wearing_lipstick", "wearing_necklace", "wearing_necktie", "wearing_mask"], "action_mapping": ["blow", "chew", "close_eyes", "cough", "cry", "drink", "eat", "frown", "gaze", "glare", "head_wagging", "kiss", "laugh", "listen_to_music", "look_around", "make_a_face", "nod", "play_instrument", "read", "shake_head", "shout", "sigh", "sing", "sleep", "smile", "smoke", "sneer", "sneeze", "sniff", "talk", "turn", "weep", "whisper", "wink", "yawn"]}, "clips": {"M2Ohb0FAaJU_1": {"ytb_id": "M2Ohb0FAaJU", "duration": {"start_sec": 81.62, "end_sec": 86.17}, "bbox": {"top": 0.0, "bottom": 0.8815, "left": 0.1964, "right": 0.6922}, "attributes": {"appearance": [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], "action": [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], "emotion": {"sep_flag": false, "labels": "neutral"}}, "version": "v0.1"}}} -------------------------------------------------------------------------------- /diagram.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/diagram.jpeg -------------------------------------------------------------------------------- /draw_warps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | # from mpl_toolkits.mplot3d import Axes3D 4 | 5 | # # Create tensors for x, y, and z coordinates 6 | # x = torch.linspace(0, 10, 50) 7 | # y = torch.linspace(0, 10, 50) 8 | # X, Y = torch.meshgrid(x, y) 9 | # Z1 = torch.sin(X) + torch.randn(X.shape) * 0.2 10 | # Z2 = torch.sin(X + 1.5) + torch.randn(X.shape) * 0.2 11 | # Z3 = Z1 + Z2 12 | 13 | # # Create a figure and 3D axis 14 | # fig = plt.figure(figsize=(8, 6)) 15 | # ax = fig.add_subplot(111, projection='3d') 16 | 17 | # # Plot the dots with quiver for direction/flow 18 | # q1 = ax.quiver(X, Y, Z1, Z1, Z1, Z1, length=0.1, normalize=True, cmap='viridis', label='x_e,k') 19 | # q2 
= ax.quiver(X, Y, Z2, Z2, Z2, Z2, length=0.1, normalize=True, cmap='plasma', label='R_d+c,k') 20 | # q3 = ax.quiver(X, Y, Z3, Z3, Z3, Z3, length=0.1, normalize=True, cmap='inferno', label='R_d+c,k + t_d') 21 | 22 | # # Set labels and title 23 | # ax.set_xlabel('x') 24 | # ax.set_ylabel('y') 25 | # ax.set_zlabel('z') 26 | # ax.set_title('PyTorch Tensor Plot (3D)') 27 | 28 | # # Add a legend 29 | # ax.legend() 30 | 31 | # # Display the plot 32 | # plt.show() 33 | 34 | 35 | import matplotlib.pyplot as plt 36 | from mpl_toolkits.mplot3d import axes3d 37 | 38 | import torch 39 | import numpy as np 40 | import torch.nn.functional as F 41 | 42 | 43 | k = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], 44 | dtype=torch.float32) 45 | base = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True) 46 | 47 | k = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0]], 48 | dtype=torch.float32) # rotate 49 | grid = F.affine_grid(k.unsqueeze(0), [1, 1, 2, 3, 4], align_corners=True) 50 | grid = grid - base 51 | grid = grid[0] 52 | 53 | D, H, W, _ = grid.shape 54 | 55 | fig = plt.figure() 56 | ax = fig.add_subplot(projection="3d") 57 | 58 | k, j, i = np.meshgrid( 59 | np.arange(0, D, 1), 60 | np.arange(0, H, 1), 61 | np.arange(0, W, 1), 62 | indexing="ij", 63 | ) 64 | 65 | u = grid[..., 0].numpy() 66 | v = grid[..., 1].numpy() 67 | w = grid[..., 2].numpy() 68 | 69 | ax.quiver(k, j, i, w, v, u, length=0.3) 70 | plt.show() -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import model 3 | from torchvision import transforms 4 | from PIL import Image 5 | import cv2 6 | import numpy as np 7 | import argparse 8 | from omegaconf import OmegaConf 9 | 10 | def load_image(image_path, transform): 11 | image = Image.open(image_path).convert("RGB") 12 | image = transform(image).unsqueeze(0) 13 | return image 14 | 15 | def inference_base(source_image_path, driving_image_path, Gbase, device): 16 | print("fyi - using normalize.") 17 | transform = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize([0.5], [0.5]) 20 | ]) 21 | 22 | # Load source and driving images 23 | source_image = load_image(source_image_path, transform) 24 | driving_image = load_image(driving_image_path, transform) 25 | 26 | # Move images to device 27 | source_image = source_image.to(device) 28 | driving_image = driving_image.to(device) 29 | 30 | # Set Gbase to evaluation mode 31 | Gbase.eval() 32 | 33 | with torch.no_grad(): 34 | # Generate output frame 35 | output_frame = Gbase(source_image, driving_image) 36 | 37 | # Convert output frame to numpy array 38 | output_frame = output_frame.squeeze(0).cpu().numpy() 39 | output_frame = np.transpose(output_frame, (1, 2, 0)) 40 | output_frame = (output_frame + 1) / 2 41 | output_frame = (output_frame * 255).astype(np.uint8) 42 | 43 | # Convert BGR to RGB 44 | output_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB) 45 | 46 | return output_frame 47 | 48 | def main(cfg: OmegaConf): 49 | use_cuda = torch.cuda.is_available() 50 | device = torch.device("cuda" if use_cuda else "cpu") 51 | 52 | # Load pretrained base model 53 | Gbase = model.Gbase().to(device) 54 | 55 | # Specify paths to source and driving images 56 | # source_image_path = "./output_images/source_frame_0.png" 57 | # driving_image_path = "./output_images/driving_frame_0.png" 58 | # Load checkpoint 59 | checkpoint = torch.load(cfg.inference.checkpoint_path) 
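# Note: strict=False on the next line tolerates key mismatches between this checkpoint and the
# current model definition; any weights missing from the checkpoint keep their fresh initialization.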
60 | Gbase.load_state_dict(checkpoint, strict=False) 61 | 62 | # Perform inference 63 | # output_frame = inference_base(source_image_path, driving_image_path, Gbase) 64 | output_frame = inference_base(cfg.inference.source_image, cfg.inference.driving_image, Gbase, device) 65 | 66 | # Save output frame 67 | cv2.imwrite(cfg.inference.output_image, output_frame) 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser(description="Inference script") 71 | parser.add_argument('--config', type=str, required=True, help='Path to the config file') 72 | args = parser.parse_args() 73 | 74 | config = OmegaConf.load(args.config) 75 | main(config) 76 | -------------------------------------------------------------------------------- /junk/-1eKufUP5XQ_4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/junk/-1eKufUP5XQ_4.mp4 -------------------------------------------------------------------------------- /junk/-2KGPYEFnsU_11.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/junk/-2KGPYEFnsU_11.mp4 -------------------------------------------------------------------------------- /junk/-2KGPYEFnsU_8.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/junk/-2KGPYEFnsU_8.mp4 -------------------------------------------------------------------------------- /junk/M2Ohb0FAaJU_1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/junk/M2Ohb0FAaJU_1.mp4 -------------------------------------------------------------------------------- /junk/download.torrent: -------------------------------------------------------------------------------- 1 | https://github.com/johndpope/Emote-hack/issues/1 2 | https://academictorrents.com/details/843b5adb0358124d388c4e9836654c246b988ff4 -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # performance metrics (L1, LPIPS, PSNR, SSIM, AKD, AED) 2 | import numpy as np 3 | import cv2 4 | import dlib # For example, using dlib for facial landmark detection 5 | import os 6 | import cv2 7 | import torch 8 | import numpy as np 9 | import lpips 10 | import skimage.metrics 11 | 12 | 13 | # Load pre-trained dlib model for facial landmark detection 14 | detector = dlib.get_frontal_face_detector() 15 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 16 | 17 | def extract_keypoints(image): 18 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 19 | faces = detector(gray) 20 | if len(faces) == 0: 21 | return None # No face detected 22 | face = faces[0] 23 | landmarks = predictor(gray, face) 24 | keypoints = np.array([(p.x, p.y) for p in landmarks.parts()]) 25 | return keypoints 26 | 27 | # Average Euclidean Distance 28 | def calculate_aed(pred, target): 29 | pred_keypoints = extract_keypoints(pred) 30 | target_keypoints = extract_keypoints(target) 31 | if pred_keypoints is None or target_keypoints is None: 32 | return None # Skip images without detected keypoints 33 | distance = 
np.linalg.norm(pred_keypoints - target_keypoints, axis=1) 34 | return np.mean(distance) 35 | 36 | def calculate_l1(pred, target): 37 | return torch.nn.functional.l1_loss(pred, target).item() 38 | 39 | def calculate_lpips(pred, target, lpips_model): 40 | return lpips_model(pred, target).item() 41 | 42 | def calculate_psnr(pred, target): 43 | return skimage.metrics.peak_signal_noise_ratio(target, pred) 44 | 45 | def calculate_ssim(pred, target): 46 | return skimage.metrics.structural_similarity(target, pred, multichannel=True) 47 | 48 | def load_image(filepath): 49 | img = cv2.imread(filepath) 50 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 51 | img = img / 255.0 # Normalize to [0, 1] 52 | return img 53 | 54 | def preprocess_image_for_lpips(img): 55 | # Convert image to PyTorch tensor and normalize to [-1, 1] as required by LPIPS 56 | img = torch.from_numpy(img).permute(2, 0, 1).float() # HWC to CHW 57 | img = img * 2 - 1 # Normalize to [-1, 1] 58 | return img.unsqueeze(0) # Add batch dimension 59 | 60 | 61 | def evaluate_metrics(output_dir, target_dir): 62 | lpips_model = lpips.LPIPS(net='alex') 63 | 64 | l1_scores = [] 65 | lpips_scores = [] 66 | psnr_scores = [] 67 | ssim_scores = [] 68 | akd_scores = [] 69 | aed_scores = [] 70 | 71 | for filename in os.listdir(output_dir): 72 | if filename.startswith("cross_reenactment_images") or filename.startswith("pred_frame"): 73 | pred_path = os.path.join(output_dir, filename) 74 | target_path = os.path.join(target_dir, filename) 75 | 76 | if os.path.exists(target_path): 77 | pred_img = load_image(pred_path) 78 | target_img = load_image(target_path) 79 | 80 | l1 = calculate_l1(torch.tensor(pred_img), torch.tensor(target_img)) 81 | lpips_score = calculate_lpips(preprocess_image_for_lpips(pred_img), preprocess_image_for_lpips(target_img), lpips_model) 82 | psnr = calculate_psnr(pred_img, target_img) 83 | ssim = calculate_ssim(pred_img, target_img) 84 | akd = calculate_akd(pred_img, target_img) 85 | aed = calculate_aed(pred_img, target_img) 86 | 87 | l1_scores.append(l1) 88 | lpips_scores.append(lpips_score) 89 | psnr_scores.append(psnr) 90 | ssim_scores.append(ssim) 91 | akd_scores.append(akd) 92 | if aed is not None: 93 | aed_scores.append(aed) 94 | 95 | return { 96 | "L1": np.mean(l1_scores), 97 | "LPIPS": np.mean(lpips_scores), 98 | "PSNR": np.mean(psnr_scores), 99 | "SSIM": np.mean(ssim_scores), 100 | "AKD": np.mean(akd_scores), 101 | "AED": np.mean(aed_scores) if aed_scores else None 102 | } 103 | 104 | 105 | output_directory = "path/to/output_images" 106 | target_directory = "path/to/target_images" 107 | 108 | metrics = evaluate_metrics(output_directory, target_directory) 109 | 110 | print(f"L1: {metrics['L1']}") 111 | print(f"LPIPS: {metrics['LPIPS']}") 112 | print(f"PSNR: {metrics['PSNR']}") 113 | print(f"SSIM: {metrics['SSIM']}") 114 | print(f"AKD: {metrics['AKD']}") 115 | print(f"AED: {metrics['AED']}") 116 | -------------------------------------------------------------------------------- /output_images/driving_frame_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/driving_frame_0.png -------------------------------------------------------------------------------- /output_images/driving_frame_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/driving_frame_1.png -------------------------------------------------------------------------------- /output_images/driving_frame_star_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/driving_frame_star_0.png -------------------------------------------------------------------------------- /output_images/driving_frame_star_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/driving_frame_star_1.png -------------------------------------------------------------------------------- /output_images/foreground_mask_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/foreground_mask_0.png -------------------------------------------------------------------------------- /output_images/masked_predicted_image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/masked_predicted_image_0.png -------------------------------------------------------------------------------- /output_images/output_frame_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/output_frame_0.png -------------------------------------------------------------------------------- /output_images/output_frame_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/output_frame_1.png -------------------------------------------------------------------------------- /output_images/source_frame_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/source_frame_0.png -------------------------------------------------------------------------------- /output_images/source_frame_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/source_frame_1.png -------------------------------------------------------------------------------- /output_images/source_frame_star_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/source_frame_star_0.png -------------------------------------------------------------------------------- /output_images/source_frame_star_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/source_frame_star_1.png 
-------------------------------------------------------------------------------- /output_images/warped_driving_frame_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/output_images/warped_driving_frame_0.png -------------------------------------------------------------------------------- /reference/EffectiveDeepNetworkHeadPose.tx: -------------------------------------------------------------------------------- 1 | An Effective Deep Network for Head Pose Estimation without Keypoints 2 | Chien Thai, Viet Tran, Minh Bui, Huong Ninh and Hai Tran 3 | Computer Vision Department, Optoelectronics Center, Viettel Aerospace Institute, Vietnam 4 | {chientv13, vietth5, minhbq6, huongnt382, haitt27}@viettel.com.vn 5 | Keywords: head pose estimation, knowledge distillation, convolutional neural network 6 | Abstract: Human head pose estimation is an essential problem in facial analysis in recent years that has a lot of computer 7 | vision applications such as gaze estimation, virtual reality, driver assistance. Because of the importance of the 8 | head pose estimation problem, it is necessary to design a compact model to resolve this task in order to reduce 9 | the computational cost when deploying on facial analysis-based applications such as large camera surveillance 10 | systems, AI cameras while maintaining accuracy. In this work, we propose a lightweight model that effectively 11 | addresses the head pose estimation problem. Our approach has two main steps. 1) We first train many teacher 12 | models on the synthesis dataset - 300W-LPA to get the head pose pseudo labels. 2) We design an architecture 13 | with the ResNet18 backbone and train our proposed model with the ensemble of these pseudo labels via the 14 | knowledge distillation process. To evaluate the effectiveness of our model, we use AFLW-2000 and BIWI - 15 | two real-world head pose datasets. Experimental results show that our proposed model significantly improves 16 | the accuracy in comparison with the state-of-the-art head pose estimation methods. Furthermore, our model 17 | has the real-time speed of ∼300 FPS when inferring on Tesla V100. 18 | 1 INTRODUCTION 19 | Head pose estimation (HPE) is an important problem in facial analysis that has been extensively researched in recent years. Its application can be widely 20 | observed in lots of intelligent computer vision systems including virtual reality (Kumar et al., 2017), 21 | driver assistance (Schwarz et al., 2017; MurphyChutorian et al., 2007), gaze estimation (MurphyChutorian and Trivedi, 2008), human-computer interaction (Seemann et al., 2004; Wang et al., 2019) and 22 | smart city surveillance. 23 | The objective of head pose estimation is to accurately identify the orientation of heads of individuals found in images. Existing methods to solve 24 | this problem can be divided into two primary categories: landmark-based approaches (Cao et al., 2014; 25 | Lathuili`ere et al., 2017; Fanelli et al., 2011; Xiong 26 | and De la Torre, 2015; Sun et al., 2013; Xin et al., 27 | 2021; Bulat and Tzimiropoulos, 2017; DeMenthon 28 | and Davis, 1995) and landmark-free approach (Ruiz 29 | et al., 2018; Yang et al., 2019; Zhou and Gregson, 30 | 2020; Chang et al., 2017). Landmark-based methods 31 | use facial keypoints extracted by landmark detectors 32 | to regress the head pose angle. 
Recently, these approaches have achieved remarkable results since the 33 | usage of deep neural networks has greatly enhanced 34 | the quality of landmark detectors. However, the problem remains challenging due to the fact that not only a 35 | minor error of landmark detectors may adversely affect the head pose estimation but learning the relation between the geometric distribution of facial landmarks and head poses is not a trivial task. Furthermore, using landmark detection as a preprocessing 36 | step imposes a computational burden for the whole 37 | process of estimating head angle which hinders its usage for real-time applications. Landmark-free methods, on the other hand, directly predict the head poses 38 | from images without detecting facial keypoints which 39 | results in their fast execution time. 40 | In addition to the above approaches, some works 41 | utilize depth information from depth cameras (Meyer 42 | et al., 2015; Fanelli et al., 2011; Mukherjee and 43 | Robertson, 2015; Martin et al., 2014). Although this 44 | approach provides a prominent result, it still has some 45 | limitations. The depth cameras are sensitive to illumination change and light conditions so that they often yield substandard results in an uncontrolled environment. Moreover, they are very expensive and use 46 | more storage and transfer time, so they are often impractical for real-time applications. 47 | Because of the importance of the head pose estimation problem and in order to minimize the processing time of the model when deploying on large 48 | arXiv:2210.13705v1 [cs.CV] 25 Oct 2022systems or embedded platforms, our goal is to design 49 | a lightweight architecture that solves this task while 50 | still guaranteeing remarkable performance. For having a compact and simple model, our network uses 51 | ResNet18 architecture as a backbone. The contributions of our work can be summarized as follows: 52 | • We address a major mistake found in HopeNet 53 | (Ruiz et al., 2018) in which annotated face boxes 54 | are mislabeled. We prove that correcting those 55 | mislabeled boxes can significantly improve the 56 | accuracy of the head pose estimation task. 57 | • An end-to-end deep architecture designed to solve 58 | head pose estimation problem is proposed. A 59 | lightweight model is trained to this task via the 60 | knowledge distillation process. 61 | • Experiments are conducted to evaluate the performance of our method on two challenging 62 | head pose datasets (BIWI and AFLW-2000). 63 | Our method achieves state-of-the-art performance 64 | when evaluating on the head pose dataset. 65 | The rest of the paper is organized as follows: Section 2 puts forward some related works on head pose 66 | estimation problem. In section 3, we present our proposed method. Section 4 discusses the datasets, experiments, results, and ablation study. Finally, the 67 | conclusion and future work are discussed in Section 68 | 5. 69 | 2 RELATED WORKS 70 | Convolutional neural networks (CNNs) are widely 71 | used in computer vision tasks and gradually replace 72 | the traditional image processing methods. CNN is designed to automatically learn the spatial features of 73 | the image by using convolution kernels. With many 74 | convolutional layers, deep networks can extract highlevel semantic features. He et al. (He et al., 2016) propose the Residual Network to train the much deeper 75 | convolutional neural network. 
ResNet uses a skip 76 | connection between the current layer and the previous 77 | layer which can learn the identity mapping and solve 78 | the vanishing gradient problem. Because of its powerful and simple architecture, ResNet and its variants 79 | (Xie et al., 2017; Zhang et al., 2020; Gao et al., 2019) 80 | are widely used in many computer vision applications 81 | and deliver high performance. 82 | Human head pose estimation has been researched over the past 25 years with many different 83 | approaches. Appearance Template (Niyogi and Freeman, 1996; Beymer, 1994; Sherrah et al., 2001; Ng 84 | and Gong, 2002; Sherrah et al., 1999) is the method 85 | that compares the input image with a set of labeled 86 | templates and assigns it to the most similar template. 87 | Detector arrays (Huang et al., 1998; Zhang et al., 88 | 2006; Jones and Viola, 2003) estimate head pose by 89 | training multiple face detectors for the different discrete poses. 90 | Many approaches are based on facial landmarks 91 | from the input image to estimate the head pose. With 92 | the progress of landmarks detection, landmark-based 93 | methods demonstrate superior performance. Dementhon et al. (DeMenthon and Davis, 1995) proposed Pose from Orthography and Scaling with Iterations which determines the head pose by 3D computer vision techniques for the given 2D face landmarks. FAN (Bulat and Tzimiropoulos, 2017) using deep neural network to estimate 3D face models. 94 | EVA-GCN (Xin et al., 2021) constructs a landmarkconnection graph and leverages the Graph Convolution Network (Yan et al., 2018) to learn the nonlinear 95 | relationships between head poses and distribution of 96 | facial keypoints. 97 | Multi-task methods combine the head pose estimation problem with other related facial analysis 98 | problems, such as face detection, keypoints detection. 99 | Some works show that learning with related tasks 100 | yields better results than learning individual tasks independently (Chen et al., 2014; Kumar et al., 2017; 101 | Zhu and Ramanan, 2012; Ranjan et al., 2017b). KEPLER (Kumar et al., 2017) predicts face detection 102 | and pose estimation jointly by using Heatmap-CNN 103 | to capture structured global and local features. Hyperface (Ranjan et al., 2017a) presents a convolutional neural network for simultaneous face detection, 104 | landmarks localization, pose estimation, and gender 105 | recognition. 106 | Gu et al. (Gu et al., 2017) proposed a dynamic 107 | facial analysis that uses a recurrent neural network. 108 | They improve head pose estimation and facial landmarks localization by leveraging the time dimension 109 | from videos instead of a single frame. 110 | For accurate head pose estimation, some methods 111 | utilize 3D information of depth images. Meyer et al. 112 | (Meyer et al., 2015) perform head pose estimation by 113 | registering 3D morphable models to depth images, using the particle swarm optimization and the iterative 114 | closest point algorithm. Fanelli et al. (Fanelli et al., 115 | 2011) using Random Regression Forests to regress the 116 | head pose estimation of depth images. 117 | Recent works directly predict the Euler angles 118 | from a single RGB image by using a deep neural network and achieve prominent performance. HopeNet 119 | (Ruiz et al., 2018) proposed a multi-loss framework 120 | that combines binned pose classification and regression loss for each Euler angle. 
By using a very stable softmax layer and cross-entropy for the binned classification loss, the network obtained robust neighborhood prediction of the head pose. [Figure 1: The overview of the head pose model. The original image is passed through the face detector to get the bounding box of the objective face. The detected face is padded to a squared image and resized to 112x112. The head pose model extracts a 62-dimensional distribution vector for the given image. The predicted pose is calculated by the expectation of this vector. For each Euler angle, the classification loss is the cross-entropy loss between the distribution vector and a one-hot vector; the regression loss is the mean square error of the ground truth and predicted pose.] FSA-Net (Yang et al., 2019) employs the soft stagewise regression scheme by training classification and regression objectives on the features from multiple stages. It provides a compact model and accurate prediction. WHENet (Zhou and Gregson, 2020) proposed a wrapped loss to estimate the full 360-degree range of the yaw angle. Our proposed network has a similar architecture to HopeNet (Ruiz et al., 2018), but has a smaller model size and achieves better performance on two challenging head pose datasets - BIWI and AFLW-2000.
3 PROPOSED METHOD
In this section, we describe the major disadvantage of previous work and the method to mitigate this problem. After that, we explain the proposed method to construct an effective head pose estimation model via the knowledge distillation process.
The head pose estimation problem can be mathematically formulated as: given a set of training images X = {x_i | i = 1..N} and ground truth Y = {y_i | i = 1..N}, where N is the number of images and y_i is the 3D vector of image x_i corresponding to the three Euler angles (yaw, pitch, roll), the goal is to find a function F so that the absolute difference between F(x) and the real head pose y for a given image x is as small as possible.
Inspired by HopeNet (Ruiz et al., 2018), we design a network using a multi-loss framework to solve this problem. HopeNet casts the regression problem of head pose estimation as a classification problem by dividing the pose range into 66 bins, each bin covering 3 degrees. The predicted pose is the expected value of the class distribution.
Further investigating HopeNet, we found that it is the data preprocessing that hinders its performance. They loosely crop around the bounding box of a face in the image and resize the cropped image to 224x224 before fitting the model. Because the height of the face bounding box is often longer than the width, this slightly changes the real head pose and causes a negative effect on the training and testing phases.
To mitigate this problem, we pad the bounding box of the face to a square shape. Given a bounding box (x1, y1, x2, y2), the padding size k is calculated by |x2 - x1 - y2 + y1| (the absolute difference between width and height). If the height h = x2 - x1 is longer than the width w = y2 - y1, the new coordinates of the bounding box (x'1, y'1, x'2, y'2) are: x'1 = x1, x'2 = x2, y'1 = y1 - [k/2], y'2 = y2 + [k/2], and vice versa.
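As a concrete illustration of the padding rule just described, here is a minimal Python sketch (not code from this repository or from the paper's authors) that pads a face bounding box to a square. The function name is hypothetical, the usual convention w = x2 - x1 and h = y2 - y1 is assumed (the paper's notation swaps the axis names), and the padding is split as k//2 and k - k//2 so the result is exactly square even for odd k, whereas the paper writes [k/2] on both sides.

```python
def pad_box_to_square(x1, y1, x2, y2):
    # Hypothetical helper illustrating the padding rule above.
    # k = |width - height| is distributed over the two ends of the shorter side.
    w, h = x2 - x1, y2 - y1
    k = abs(w - h)
    if h > w:
        # height is longer: grow the box along x
        x1, x2 = x1 - k // 2, x2 + (k - k // 2)
    else:
        # width is longer: grow the box along y
        y1, y2 = y1 - k // 2, y2 + (k - k // 2)
    return x1, y1, x2, y2

# e.g. pad_box_to_square(10, 20, 60, 90) -> (0, 20, 70, 90): a 50x70 box becomes 70x70,
# which is then resized to 112x112 before being fed to the model.
```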
After getting the square image of the face, we resize it to (112, 112) in order to decrease the computation cost during training and inference.
Unlike HopeNet, we divide the pose range from -93 to 93 into 62 bins for each Euler angle. The classification loss of an angle is the cross-entropy loss between the softmax output of the model and the pose's corresponding one-hot vector:

L_cls^angle = \sum_{i=1}^{N} y'_i \log(\hat{y}_i)    (1)

where y'_i and \hat{y}_i are respectively the one-hot vector of the pose and the predicted softmax output for a given input x_i. The predicted pose of x_i is the expected value of the softmax output, denoted by r_i. The regression loss of an angle is the mean squared error between the ground-truth label y_i and the predicted pose r_i:

L_reg^angle = (1/N) \sum_{i=1}^{N} \| r_i - y_i \|^2    (2)

The total loss is composed of three separate losses, each calculated as the sum of the classification and regression losses of an angle, as follows:

L = \sum_{angle} ( L_cls^angle + L_reg^angle ),  where angle ∈ {yaw, pitch, roll}    (3)

[Figure 2: The overview of the proposed method. The student model uses a ResNet18 backbone. The head pose loss is the sum of the Kullback-Leibler divergence loss between the softmax output of the student model and the ensemble output of the head pose teacher models on each of the yaw, pitch, and roll angles. The total loss is the sum of the distillation losses of the three Euler angles.]

The above method uses hard labels to train head pose estimation models. Inspired by (Hinton et al., 2015), we use knowledge distillation to construct a compact model while enhancing the performance on this task. Our network uses ResNet18 (He et al., 2016) as the backbone, a simple and small architecture which is trained to match the output of the head pose teacher models (pseudo labels). With supervised learning, models are trained to match the same labels, but with different initializations and architectures they focus on distinctive features. So, we ensemble the outputs of several strong head pose models to get more informative teacher features.
Given N_teacher head pose models, we ensemble them by calculating the mean of their regression outputs. This is equal to the expected value of the mean softmax outputs of these models. So, the output after ensembling N_teacher teacher head pose models is:

y_i^ens = (1/N_teacher) \sum_{j=1}^{N_teacher} \hat{y}_i^j    (4)

where \hat{y}_i^j is the softmax output of head pose teacher model j for a given image x_i.
The loss function for the head pose task is the Kullback-Leibler divergence between the softmax output of the student model \hat{y}^t and the ensemble output of the teacher models y^ens:

L_head_pose = - \sum_{i=1}^{N} y_i^ens \log( \hat{y}_i^t / y_i^ens )    (5)

Because head pose estimation is a challenging task, we found that the stronger the model in terms of parameters and computation cost, the more capacity it has to achieve good results. Based on their performance on the ImageNet dataset (Deng et al., 2009), we train three head pose teacher models from scratch whose backbones are chosen respectively as ResNet101 (He et al., 2016), BotNet101 (Srinivas et al., 2021), and Res2Net101 (Gao et al., 2019).
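To make Eqs. (1)-(5) concrete, below is a minimal PyTorch sketch of the binned multi-loss and the ensemble distillation loss. This is an illustrative re-implementation rather than the authors' code or anything in this repository; the bin layout (62 bins of 3 degrees over [-93, 93]), the helper names, and the tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

# 62 bins over [-93, 93] degrees, 3 degrees per bin (bin centers are an assumption)
BIN_CENTERS = torch.arange(62) * 3.0 - 93.0 + 1.5   # shape (62,)

def hard_label_loss(logits, angles_deg):
    """Eqs. (1)-(2) for a single Euler angle.

    logits:      (B, 62) raw scores for one angle
    angles_deg:  (B,) ground-truth angle in degrees
    """
    # classification target: index of the bin containing the angle
    bin_idx = ((angles_deg + 93.0) / 3.0).long().clamp(0, 61)
    cls_loss = F.cross_entropy(logits, bin_idx)                 # Eq. (1)
    # regression target: expectation of the softmax distribution
    probs = F.softmax(logits, dim=1)
    pred_deg = (probs * BIN_CENTERS.to(probs)).sum(dim=1)       # r_i
    reg_loss = F.mse_loss(pred_deg, angles_deg)                 # Eq. (2)
    return cls_loss + reg_loss                                  # one term of Eq. (3)

def distillation_loss(student_logits, teacher_probs_list):
    """Eqs. (4)-(5): KL divergence against the averaged teacher softmax."""
    y_ens = torch.stack(teacher_probs_list, dim=0).mean(dim=0)  # Eq. (4)
    log_q = F.log_softmax(student_logits, dim=1)
    # KL(y_ens || student) summed over bins, averaged over the batch
    return F.kl_div(log_q, y_ens, reduction="batchmean")        # Eq. (5)
```

Under these assumptions, the total loss of Eq. (3) is the sum of `hard_label_loss` over the yaw, pitch, and roll heads, and `distillation_loss` corresponds to Eq. (5) up to the batch averaging performed by `reduction="batchmean"`.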
After that, we train a head pose model with a ResNet18 backbone using the aforementioned head pose knowledge distillation strategy.
In our experiments, we observed that the big models (teacher models) often give larger probabilities to the bins in the proximity of the true bin and smaller scores to the ones far away. This is valuable information (i.e. the faces in a bin are more likely the faces in its neighbor bins), but it has very little effect on the cross-entropy cost function during training if the probabilities are very close to zero. This means the soft targets of the teacher models carry more varied information than one-hot labels, which helps the small model (student model) learn easily. So, we argue that the distilled model can preserve the generalization of the teacher models and reach highly accurate results.
4 Experimental results
In this section, we describe the datasets for training and testing, the implementation, results, comparisons with other state-of-the-art methods, and the ablation study.
4.1 Dataset
Head pose datasets: In our experiments, we use three popular datasets for the head pose estimation problem: 300W-LPA (Hsu et al., 2019), AFLW-2000 (Zhu and Ramanan, 2012), and BIWI (Fanelli et al., 2011). 300W-LPA is a synthetically expanded dataset that provides over 350000 images across large poses.

Table 1: Mean average error of Euler angles across both state-of-the-art landmark-based and landmark-free methods on the BIWI and AFLW2000 datasets.
                                        |        BIWI          |       AFLW2000
Model                                   | Yaw  Pitch Roll MAE  | Yaw  Pitch Roll MAE
KEPLER (Kumar et al., 2017)             | 8.80 17.3  16.2 13.9 | -    -    -    -
FAN (Bulat and Tzimiropoulos, 2017)     | 8.53 7.48  7.63 7.89 | 6.36 12.3 8.71 9.12
Dlib (Kazemi and Sullivan, 2014)        | 16.8 13.8  6.19 12.2 | 23.1 13.6 10.5 15.8
3DDFA (Zhu et al., 2016)                | -    -     -    -    | 5.40 8.53 8.25 7.39
EVA-GCN (Xin et al., 2021)              | 4.01 4.78  2.98 3.92 | 4.46 5.34 4.11 4.64
HopeNet (α = 2) (Ruiz et al., 2018)     | 5.17 6.98  3.39 5.18 | 6.47 6.56 5.44 6.16
HopeNet (α = 1) (Ruiz et al., 2018)     | 4.81 6.61  3.27 4.90 | 6.92 6.64 5.67 6.41
SSR-Net-MD (Yang et al., 2018)          | 4.49 6.31  3.61 4.65 | 5.14 7.09 5.89 6.01
FSA-Caps-Fusion (Yang et al., 2019)     | 4.27 4.96  2.76 4.00 | 4.50 6.08 4.64 5.07
WHENet-V (Zhou and Gregson, 2020)       | 3.60 4.10  2.73 3.48 | 4.44 5.75 4.31 4.83
EHPNet (Ours)                           | 3.68 4.03  2.57 3.43 | 3.23 5.54 3.88 4.15

[Figure 3: Some examples of face images from the datasets. The first row is from 300W-LPA (Hsu et al., 2019), which is a synthetic dataset. The second and third rows are respectively from AFLW-2000 (Zhu et al., 2016) and BIWI (Fanelli et al., 2011) - two real-world datasets.]

The AFLW-2000 dataset provides head pose ground truth and the corresponding 68 landmark points for 2000 3-D face images. Images in the AFLW-2000 dataset have large pose annotations and various lighting conditions. The BIWI dataset uses a Kinect v2 device to record RGB-D videos of different subjects. It contains 24 videos of 20 subjects across different head poses. There are roughly 15000 samples in this dataset; each sample contains RGB and depth images, and the pose annotations were created by using depth information.
Following HopeNet, we use 300W-LPA for training while testing on AFLW-2000 and BIWI - two real-world datasets.
In our case, we only use the RGB images for training in these datasets. We run RetinaFace (Deng et al., 2019) on all images to get the coordinates of the face bounding boxes.
4.2 Implementation
For better estimation on low-resolution face images, we augment the head pose training dataset by randomly downsampling and upsampling back to the original image size, randomly adjusting brightness and contrast, and blurring with a random Gaussian kernel. We also randomly flip the image and relabel the yaw and roll angles of the flipped image to -yaw and -roll to get more training data.
We use PyTorch for implementing the proposed network. We use 100 epochs to train the teacher networks with hard labels and 200 epochs for the knowledge distillation process. The chosen optimizer is Adam with an initial learning rate of 1e-4. The learning rate is reduced by the cosine annealing strategy. The experiments are performed on a computer with a Tesla V100 GPU.
4.3 Results
We compare our proposed network with other state-of-the-art head pose estimation methods on the BIWI and AFLW datasets. KEPLER (Kumar et al., 2017), FAN (Bulat and Tzimiropoulos, 2017), Dlib (Kazemi and Sullivan, 2014) and EVA-GCN (Xin et al., 2021) are landmark-based methods. KEPLER (Kumar et al., 2017) uses a modified GoogleNet to predict facial landmark points and pose at the same time. Dlib (Kazemi and Sullivan, 2014) is a face library that uses an ensemble of regression trees to detect landmarks. FAN (Bulat and Tzimiropoulos, 2017) is a state-of-the-art landmark detection method. EVA-GCN (Xin et al., 2021) is a state-of-the-art landmark-based method which constructs a landmark-connection graph and leverages the Graph Convolution Network (Yan et al., 2018) to learn the nonlinear relationships between head poses and the distribution of facial keypoints. HopeNet (Ruiz et al., 2018), FSANet (Yang et al., 2019) and WHENet (Zhou and Gregson, 2020) are landmark-free methods which treat the regression problem as a classification problem by dividing the pose range into different classes. The α coefficient of HopeNet is the weight of the regression losses.

[Figure 4: The scatter plot between yaw, pitch, roll values and errors on the AFLW-2000 dataset.]
[Figure 5: The scatter plot between yaw, pitch, roll values and errors on the BIWI dataset.]

Table 1 shows the comparisons of our proposed network with the above models on the BIWI and AFLW-2000 datasets respectively. The evaluation metric is the mean absolute error of the Euler angles.
As shown in Table 1, our proposed EHPNet achieves state-of-the-art results on both the AFLW-2000 and BIWI datasets. It outperforms the previous state-of-the-art WHENet by 14.9% and 1.15%, respectively. EHPNet has a similar architecture to WHENet and HopeNet but uses a smaller model and image size. Even so, it shows a more significant improvement compared to WHENet and HopeNet. Furthermore, our model has a real-time speed of ~300 FPS when inferring on a Tesla V100.
Figure 6 shows some visualizations of predicted images on BIWI and AFLW-2000. Besides achieving good results on images with varied poses and various lighting conditions, the network can also predict well on face images with heavy occlusion.

[Figure 6: Results of the proposed network. The blue line points towards the front of the face, the green line points in the downward direction, and the red line points to the side. The first row is the prediction on the BIWI dataset. The second row is the estimation result on images with various lighting conditions from AFLW-2000.]
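For reference, the mean absolute error reported per Euler angle in Tables 1 and 2 can be computed with a few lines of NumPy. This is an illustrative sketch, not the authors' evaluation code, and the function and array names are assumptions.

```python
import numpy as np

def euler_mae(pred_deg, gt_deg):
    # pred_deg, gt_deg: (num_images, 3) arrays of (yaw, pitch, roll) in degrees
    abs_err = np.abs(np.asarray(pred_deg) - np.asarray(gt_deg))
    yaw, pitch, roll = abs_err.mean(axis=0)      # per-angle MAE
    return yaw, pitch, roll, abs_err.mean()      # overall MAE is their average
```

For example, the EHPNet row of Table 1 on BIWI corresponds to per-angle values of 3.68, 4.03 and 2.57 and an overall MAE of 3.43.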
The scatter diagrams in Figure 4 and Figure 5 show the influence of the pose value on the prediction error for each angle on the AFLW-2000 and BIWI datasets. On AFLW-2000, the yaw angle has a smaller mean absolute error than the pitch and roll angles and has a stable prediction across the pose range. For the pitch and roll angles, the model tends to predict well if the value of the pose is close to 0. As shown in Figure 4, some predictions have a very big error. We find that this happens because the head poses on AFLW-2000 are provided by using 3D landmarks, so some examples can show a big difference in pose when viewed as RGB images. On the BIWI dataset, the discrepancy between ground truth and prediction is not significant, but the trend is slightly changed. The model predicts better on the yaw and roll angles. A higher pitch value leads to face occlusion in the image and makes the model confused. As shown in Figure 5, many samples from the BIWI dataset don't follow the trend. For example, a sample which has a yaw angle value close to zero has the maximum error. In our experiments, we observed that, although it has a small yaw angle, it has a large pitch and roll, so the face can be occluded, which leads to wrong predictions.

Table 2: The impact of different backbones and distillation training on head pose estimation models. The evaluation metric is the mean absolute error of Euler angles.
                    |          BIWI             |         AFLW2000
Backbone            | Yaw   Pitch  Roll   MAE   | Yaw   Pitch  Roll   MAE
ResNet18            | 3.969 4.849  2.869  3.897 | 3.785 5.642  4.238  4.555
ResNet101           | 3.680 3.945  2.755  3.460 | 3.249 5.276  3.821  4.115
BotNet101           | 3.876 4.066  2.528  3.489 | 3.559 5.109  3.697  4.135
Res2Net101          | 3.827 3.939  2.669  3.478 | 3.223 5.080  3.556  3.953
Ensemble            | 3.688 3.859  2.508  3.352 | 3.169 5.009  3.560  3.913
Distilled ResNet18  | 3.683 4.033  2.571  3.429 | 3.226 5.345  3.876  4.148

4.4 Ablation study
We have conducted an ablation study on changing the backbone and on using pseudo labels from teacher models for the head pose estimation task. As shown in Table 2, the results improve significantly when using padding instead of resizing the cropped face image as HopeNet does. By using the ensemble, the mean absolute errors are slightly decreased on both the BIWI and AFLW datasets. The small head pose model achieves better accuracy when trained via the knowledge distillation process. With ResNet18 as a backbone, the pose model using pseudo labels from the ensembled output of many teacher models is better than the same model using hard labels. The distilled head pose model has results equivalent to its teachers, or even better. In our experiments, we observed that these models predict essentially the same result for the yaw angle. But for the pitch and roll angles, the complex head pose models work better. Among the three teacher models, each one predicts better in a specific pose interval (i.e.
model ResNet101 achieves the smallest error of 401 | yaw with 3.680, while Res2Net101 outperforms the 402 | two others with the pitch error of 3.939 and the last 403 | one BotNet101 attains the best roll error of 2.528), 404 | but worse at the others. By preserving the generalization of each model, the ensemble results have a stable 405 | prediction on the poses range. Overall, training the 406 | baseline model with hard targets leads to severe overfitting, whereas training the same model with ensemble soft targets is able to recover better generalization 407 | and achieves competitive results. 408 | 5 CONCLUSIONS 409 | In this paper, we have presented an EHPNet, which 410 | can directly, accurately, and robustly predict the 411 | head rotation from a single RGB image. This is 412 | achieved by mitigating the disadvantages of previous work, along with distilling the knowledge from 413 | many robust head pose estimation teacher models. 414 | By using ResNet18 architecture as a backbone, the 415 | model is compact and usable in many computer vision applications. The proposed network outperforms 416 | both landmark-based and landmark-free methods and 417 | achieves state-of-the-art results on both the AFLW2000 and BIWI datasets, with higher respectively 418 | 14.9% and 1.15% than WHENet. 419 | In the future, we would like to use a lower computation network as well as reduce the input image 420 | resolution. Besides, more effective knowledge dis-tillation techniques will be used to help the student 421 | model achieve better accuracy. 422 | REFERENCES 423 | Beymer, D. (1994). Face recognition under varying pose. 424 | In CVPR, volume 94, page 137. Citeseer. 425 | Bulat, A. and Tzimiropoulos, G. (2017). How far are 426 | we from solving the 2d & 3d face alignment problem?(and a dataset of 230,000 3d facial landmarks). 427 | In Proceedings of the IEEE International Conference 428 | on Computer Vision, pages 1021–1030. 429 | Cao, X., Wei, Y., Wen, F., and Sun, J. (2014). Face alignment by explicit shape regression. International journal of computer vision, 107(2):177–190. 430 | Chang, F.-J., Tuan Tran, A., Hassner, T., Masi, I., Nevatia, 431 | R., and Medioni, G. (2017). Faceposenet: Making a 432 | case for landmark-free face alignment. In Proceedings of the IEEE International Conference on Computer Vision Workshops, pages 1599–1608. 433 | Chen, D., Ren, S., Wei, Y., Cao, X., and Sun, J. (2014). 434 | Joint cascade face detection and alignment. In European conference on computer vision, pages 109–122. 435 | Springer. 436 | DeMenthon, D. F. and Davis, L. S. (1995). Model-based 437 | object pose in 25 lines of code. International journal 438 | of computer vision, 15(1-2):123–141. 439 | Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and FeiFei, L. (2009). Imagenet: A large-scale hierarchical 440 | image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. 441 | Ieee. 442 | Deng, J., Guo, J., Zhou, Y., Yu, J., Kotsia, I., and Zafeiriou, 443 | S. (2019). Retinaface: Single-stage dense face localisation in the wild. arXiv preprint arXiv:1905.00641. 444 | Fanelli, G., Weise, T., Gall, J., and Van Gool, L. (2011). 445 | Real time head pose estimation from consumer depth 446 | cameras. In Joint pattern recognition symposium, 447 | pages 101–110. Springer. 448 | Gao, S., Cheng, M.-M., Zhao, K., Zhang, X.-Y., Yang, M.H., and Torr, P. H. (2019). Res2net: A new multi-scale 449 | backbone architecture. 
IEEE transactions on pattern 450 | analysis and machine intelligence. 451 | Gu, J., Yang, X., De Mello, S., and Kautz, J. (2017). Dynamic facial analysis: From bayesian filtering to recurrent neural network. In Proceedings of the IEEE 452 | conference on computer vision and pattern recognition, pages 1548–1557. 453 | He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of 454 | the IEEE conference on computer vision and pattern 455 | recognition, pages 770–778. 456 | Hinton, G., Vinyals, O., and Dean, J. (2015). Distilling 457 | the knowledge in a neural network. arXiv preprint 458 | arXiv:1503.02531. 459 | Hsu, G.-S., Huang, W.-F., and Yap, M. H. (2019). Edgeembedded multi-dropout framework for real-time face 460 | alignment. IEEE Access, 8:6032–6044. 461 | Huang, J., Shao, X., and Wechsler, H. (1998). Face pose 462 | discrimination using support vector machines (svm). 463 | In Proceedings. fourteenth international conference 464 | on pattern recognition (Cat. No. 98EX170), volume 1, 465 | pages 154–156. IEEE. 466 | Jones, M. and Viola, P. (2003). Fast multi-view face detection. Mitsubishi Electric Research Lab TR-20003-96, 467 | 3(14):2. 468 | Kazemi, V. and Sullivan, J. (2014). One millisecond face 469 | alignment with an ensemble of regression trees. In 470 | Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1867–1874. 471 | Kumar, A., Alavi, A., and Chellappa, R. (2017). Kepler: 472 | Keypoint and pose estimation of unconstrained faces 473 | by learning efficient h-cnn regressors. In 2017 12th 474 | ieee international conference on automatic face & 475 | gesture recognition (fg 2017), pages 258–265. IEEE. 476 | Lathuili`ere, S., Juge, R., Mesejo, P., Munoz-Salinas, R., and 477 | Horaud, R. (2017). Deep mixture of linear inverse 478 | regressions applied to head-pose estimation. In Proceedings of the IEEE Conference on Computer Vision 479 | and Pattern Recognition, pages 4817–4825. 480 | Martin, M., Van De Camp, F., and Stiefelhagen, R. (2014). 481 | Real time head model creation and head pose estimation on consumer depth cameras. In 2014 2nd International Conference on 3D Vision, volume 1, pages 482 | 641–648. IEEE. 483 | Meyer, G. P., Gupta, S., Frosio, I., Reddy, D., and Kautz, J. 484 | (2015). Robust model-based 3d head pose estimation. 485 | In Proceedings of the IEEE international conference 486 | on computer vision, pages 3649–3657. 487 | Mukherjee, S. S. and Robertson, N. M. (2015). Deep head 488 | pose: Gaze-direction estimation in multimodal video. 489 | IEEE Transactions on Multimedia, 17(11):2094– 490 | 2107. 491 | Murphy-Chutorian, E., Doshi, A., and Trivedi, M. M. 492 | (2007). Head pose estimation for driver assistance 493 | systems: A robust algorithm and experimental evaluation. In 2007 IEEE intelligent transportation systems 494 | conference, pages 709–714. IEEE. 495 | Murphy-Chutorian, E. and Trivedi, M. M. (2008). Head 496 | pose estimation in computer vision: A survey. IEEE 497 | transactions on pattern analysis and machine intelligence, 31(4):607–626. 498 | Ng, J. and Gong, S. (2002). Composite support vector machines for detection of faces across views and pose estimation. Image and Vision Computing, 20(5-6):359– 499 | 368. 500 | Niyogi, S. and Freeman, W. T. (1996). Example-based head 501 | tracking. In Proceedings of the second international 502 | conference on automatic face and gesture recognition, 503 | pages 374–378. IEEE. 504 | Ranjan, R., Patel, V. 
M., and Chellappa, R. (2017a). Hyperface: A deep multi-task learning framework for 505 | face detection, landmark localization, pose estimation, 506 | and gender recognition. IEEE transactions on pattern 507 | analysis and machine intelligence, 41(1):121–135. 508 | Ranjan, R., Sankaranarayanan, S., Castillo, C. D., and Chellappa, R. (2017b). An all-in-one convolutional neural network for face analysis. In 2017 12th IEEE In-ternational Conference on Automatic Face & Gesture 509 | Recognition (FG 2017), pages 17–24. IEEE. 510 | Ruiz, N., Chong, E., and Rehg, J. M. (2018). Fine-grained 511 | head pose estimation without keypoints. In Proceedings of the IEEE conference on computer vision and 512 | pattern recognition workshops, pages 2074–2083. 513 | Schwarz, A., Haurilet, M., Martinez, M., and Stiefelhagen, 514 | R. (2017). Driveahead-a large-scale driver head pose 515 | dataset. In Proceedings of the IEEE Conference on 516 | Computer Vision and Pattern Recognition Workshops, 517 | pages 1–10. 518 | Seemann, E., Nickel, K., and Stiefelhagen, R. (2004). Head 519 | pose estimation using stereo vision for human-robot 520 | interaction. In Sixth IEEE International Conference 521 | on Automatic Face and Gesture Recognition, 2004. 522 | Proceedings., pages 626–631. IEEE. 523 | Sherrah, J., Gong, S., and Ong, E.-J. (1999). Understanding pose discrimination in similarity space. In BMVC, 524 | pages 1–10. Citeseer. 525 | Sherrah, J., Gong, S., and Ong, E.-J. (2001). Face distributions in similarity space under varying head pose. 526 | Image and Vision Computing, 19(12):807–819. 527 | Srinivas, A., Lin, T.-Y., Parmar, N., Shlens, J., Abbeel, P., 528 | and Vaswani, A. (2021). Bottleneck transformers for 529 | visual recognition. In Proceedings of the IEEE/CVF 530 | Conference on Computer Vision and Pattern Recognition, pages 16519–16529. 531 | Sun, Y., Wang, X., and Tang, X. (2013). Deep convolutional network cascade for facial point detection. In 532 | Proceedings of the IEEE conference on computer vision and pattern recognition, pages 3476–3483. 533 | Wang, Y., Liang, W., Shen, J., Jia, Y., and Yu, L.-F. (2019). 534 | A deep coarse-to-fine network for head pose estimation from synthetic data. Pattern Recognition, 94:196– 535 | 206. 536 | Xie, S., Girshick, R., Doll ́ar, P., Tu, Z., and He, K. (2017). 537 | Aggregated residual transformations for deep neural 538 | networks. In Proceedings of the IEEE conference on 539 | computer vision and pattern recognition, pages 1492– 540 | 1500. 541 | Xin, M., Mo, S., and Lin, Y. (2021). Eva-gcn: Head pose 542 | estimation based on graph convolutional networks. In 543 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 1462– 544 | 1471. 545 | Xiong, X. and De la Torre, F. (2015). Global supervised descent method. In Proceedings of the IEEE Conference 546 | on Computer Vision and Pattern Recognition, pages 547 | 2664–2673. 548 | Yan, S., Xiong, Y., and Lin, D. (2018). Spatial temporal 549 | graph convolutional networks for skeleton-based action recognition. In Thirty-second AAAI conference 550 | on artificial intelligence. 551 | Yang, T.-Y., Chen, Y.-T., Lin, Y.-Y., and Chuang, Y.-Y. 552 | (2019). Fsa-net: Learning fine-grained structure aggregation for head pose estimation from a single image. In Proceedings of the IEEE/CVF Conference 553 | on Computer Vision and Pattern Recognition, pages 554 | 1087–1096. 555 | Yang, T.-Y., Huang, Y.-H., Lin, Y.-Y., Hsiu, P.-C., and 556 | Chuang, Y.-Y. (2018). 
Ssr-net: A compact soft stagewise regression network for age estimation. In IJCAI, 557 | volume 5, page 7. 558 | Zhang, H., Wu, C., Zhang, Z., Zhu, Y., Lin, H., Zhang, 559 | Z., Sun, Y., He, T., Mueller, J., Manmatha, R., et al. 560 | (2020). Resnest: Split-attention networks. arXiv 561 | preprint arXiv:2004.08955. 562 | Zhang, Z., Hu, Y., Liu, M., and Huang, T. (2006). Head 563 | pose estimation in seminar room using multi view face 564 | detectors. In International evaluation workshop on 565 | classification of events, activities and relationships, 566 | pages 299–304. Springer. 567 | Zhou, Y. and Gregson, J. (2020). Whenet: Real-time finegrained estimation for wide range head pose. arXiv 568 | preprint arXiv:2005.10353. 569 | Zhu, X., Lei, Z., Liu, X., Shi, H., and Li, S. Z. (2016). Face 570 | alignment across large poses: A 3d solution. In Proceedings of the IEEE conference on computer vision 571 | and pattern recognition, pages 146–155. 572 | Zhu, X. and Ramanan, D. (2012). Face detection, pose estimation, and landmark localization in the wild. In 573 | 2012 IEEE conference on computer vision and pattern 574 | recognition, pages 2879–2886. IEEE. -------------------------------------------------------------------------------- /reference/G2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/G2d.png -------------------------------------------------------------------------------- /reference/flowfields.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/flowfields.png -------------------------------------------------------------------------------- /reference/google_scholar_profile_results_data.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /reference/megaportrait-network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/megaportrait-network.png -------------------------------------------------------------------------------- /reference/megaportrait-student.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/megaportrait-student.png -------------------------------------------------------------------------------- /reference/references.txt: -------------------------------------------------------------------------------- 1 | REFERENCES 2 | [1] Volker Blanz and Thomas Vetter. 1999. A morphable model for the synthesis of 3 | 3D faces. In SIGGRAPH ’99. 4 | [2] Adrian Bulat and Georgios Tzimiropoulos. 2017. How Far are We from Solving 5 | the 2D & 3D Face Alignment Problem? (and a Dataset of 230,000 3D Facial 6 | Landmarks). 2017 IEEE International Conference on Computer Vision (ICCV) (2017), 7 | 1021–1030. 8 | [3] Egor Burkov, I. Pasechnik, Artur Grigorev, and Victor S. Lempitsky. 2020. Neural 9 | Head Reenactment with Latent Pose Descriptors. 2020 IEEE/CVF Conference on 10 | Computer Vision and Pattern Recognition (CVPR) (2020), 13783–13792. 
11 | [4] Joon Son Chung, Arsha Nagrani, and Andrew Zisserman. 2018. VoxCeleb2: Deep 12 | Speaker Recognition. In INTERSPEECH. 13 | [5] Kevin Cortacero, Tobias Fischer, and Yiannis Demiris. 2019. RT-BENE: A Dataset 14 | and Baselines for Real-Time Blink Estimation in Natural Environments. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) 15 | Workshops. 16 | [6] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, K. Li, and Li Fei-Fei. 2009. ImageNet: 17 | A large-scale hierarchical image database. In CVPR. 18 | [7] Jiankang Deng, J. Guo, Evangelos Ververas, Irene Kotsia, Stefanos Zafeiriou, and 19 | InsightFace FaceSoft. 2020. RetinaFace: Single-Shot Multi-Level Face Localisation 20 | in the Wild. 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition 21 | (CVPR) (2020), 5202–5211. 22 | [8] Michail Christos Doukas, Stefanos Zafeiriou, and Viktoriia Sharmanska. 2021. 23 | HeadGAN: One-shot Neural Head Synthesis and Editing. 2021 IEEE/CVF International Conference on Computer Vision (ICCV). 24 | [9] Tobias Fischer, Hyung Jin Chang, and Y. Demiris. 2018. RT-GENE: Real-Time 25 | Eye Gaze Estimation in Natural Environments. In ECCV. 26 | [10] Guy Gafni, Justus Thies, Michael Zollhofer, and Matthias Nießner. 2021. Dynamic 27 | Neural Radiance Fields for Monocular 4D Facial Avatar Reconstruction. 2021 28 | IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 8645– 29 | 8654. 30 | [11] Ke Gong, Yiming Gao, Xiaodan Liang, Xiaohui Shen, M. Wang, and Liang Lin. 31 | 2019. Graphonomy: Universal Human Parsing via Graph Transfer Learning. 2019 32 | IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2019), 33 | 7442–7451. 34 | [12] Sungjoo Ha, Martin Kersner, Beomsu Kim, Seokjun Seo, and Dongyoung Kim. 35 | 2020. MarioNETte: Few-shot Face Reenactment Preserving Identity of Unseen 36 | Targets. In AAAI. 37 | [13] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and 38 | Sepp Hochreiter. 2017. GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium. In Advances in Neural Information Processing 39 | Systems. 40 | [14] Justin Johnson, Alexandre Alahi, and Li Fei-Fei. 2016. Perceptual Losses for 41 | Real-Time Style Transfer and Super-Resolution. In ECCV. 42 | [15] Tero Karras, Samuli Laine, and Timo Aila. 2019. A Style-Based Generator Architecture for Generative Adversarial Networks. 2019 IEEE/CVF Conference on 43 | Computer Vision and Pattern Recognition (CVPR), 4396–4405. 44 | [16] Zhanghan Ke, Jiayu Sun, Kaican Li, Qiong Yan, and Rynson W.H. Lau. 2022. 45 | MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition. 46 | In AAAI. 47 | [17] Hyeongwoo Kim, Pablo Garrido, Ayush Tewari, Weipeng Xu, Justus Thies, 48 | Matthias Nießner, Patrick Pérez, Christian Richardt, Michael Zollhöfer, and Christian Theobalt. 2018. Deep video portraits. ACM Transactions on Graphics (TOG) 49 | 37 (2018), 1 – 14. 50 | [18] Stephen Lombardi, Jason M. Saragih, Tomas Simon, and Yaser Sheikh. 2018. Deep 51 | appearance models for face rendering. ACM Transactions on Graphics (TOG) 37 52 | (2018), 1 – 13. 53 | [19] Stephen Lombardi, Tomas Simon, Jason M. Saragih, Gabriel Schwartz, Andreas M. 54 | Lehrmann, and Yaser Sheikh. 2019. Neural volumes. ACM Transactions on 55 | Graphics (TOG) 38 (2019), 1 – 14. 56 | [20] Ilya Loshchilov and Frank Hutter. 2019. Decoupled Weight Decay Regularization. 57 | In ICLR. 58 | [21] Ben Mildenhall, Pratul P. Srinivasan, Matthew Tancik, Jonathan T. 
Barron, Ravi 59 | Ramamoorthi, and Ren Ng. 2020. NeRF: Representing Scenes as Neural Radiance 60 | Fields for View Synthesis. In ECCV. 61 | [22] Keunhong Park, U. Sinha, Jonathan T. Barron, Sofien Bouaziz, Dan B. Goldman, 62 | Steven M. Seitz, and Ricardo Martin-Brualla. 2021. Nerfies: Deformable Neural 63 | Radiance Fields. 2021 IEEE/CVF International Conference on Computer Vision 64 | (ICCV). 65 | [23] Keunhong Park, U. Sinha, Peter Hedman, Jonathan T. Barron, Sofien Bouaziz, 66 | Dan B. Goldman, Ricardo Martin-Brualla, and Steven M. Seitz. 2021. HyperNeRF: 67 | A Higher-Dimensional Representation for Topologically Varying Neural Radiance 68 | Fields. ArXiv. 69 | [24] Omkar M. Parkhi, Andrea Vedaldi, and Andrew Zisserman. 2015. Deep Face 70 | Recognition. In BMVC. 71 | [25] Aliaksandr Siarohin, Stéphane Lathuilière, S. Tulyakov, Elisa Ricci, and N. Sebe. 72 | 2019. Animating Arbitrary Objects via Deep Motion Transfer. 2019 IEEE/CVF 73 | Conference on Computer Vision and Pattern Recognition (CVPR) (2019), 2372–2381. 74 | [26] Aliaksandr Siarohin, Stéphane Lathuilière, S. Tulyakov, Elisa Ricci, and N. Sebe. 75 | 2019. First Order Motion Model for Image Animation. ArXiv abs/2003.00196 76 | (2019). 77 | [27] Karen Simonyan and Andrew Zisserman. 2015. Very Deep Convolutional Networks for Large-Scale Image Recognition. CoRR abs/1409.1556 (2015). 78 | [28] Shaolin Su, Qingsen Yan, Yu Zhu, Cheng Zhang, Xin Ge, Jinqiu Sun, and Yanning 79 | Zhang. 2020. Blindly Assess Image Quality in the Wild Guided by a Self-Adaptive 80 | Hyper Network. In IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). 81 | [29] Roman Suvorov, Elizaveta Logacheva, Anton Mashikhin, Anastasia Remizova, 82 | Arsenii Ashukha, Aleksei Silvestrov, Naejin Kong, Harshith Goka, Kiwoong Park, 83 | and Victor Lempitsky. 2021. Resolution-robust Large Mask Inpainting with 84 | Fourier Convolutions. arXiv preprint arXiv:2109.07161 (2021). 85 | [30] Justus Thies, Michael Zollhöfer, Marc Stamminger, Christian Theobalt, and 86 | Matthias Nießner. 2019. Face2Face: real-time face capture and reenactment 87 | of RGB videos. ArXiv abs/2007.14808 (2019). 88 | [31] H. Wang, Yitong Wang, Zheng Zhou, Xing Ji, Zhifeng Li, Dihong Gong, Jin 89 | Zhou, and Wenyu Liu. 2018. CosFace: Large Margin Cosine Loss for Deep 90 | Face Recognition. 2018 IEEE/CVF Conference on Computer Vision and Pattern 91 | Recognition (2018), 5265–5274. 92 | [32] Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Guilin Liu, Andrew Tao, Jan Kautz, 93 | and Bryan Catanzaro. 2018. Video-to-Video Synthesis. In Advances in Neural 94 | Information Processing Systems (NeurIPS). 95 | [33] Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, and Bryan 96 | Catanzaro. 2018. High-Resolution Image Synthesis and Semantic Manipulation 97 | with Conditional GANs. In Proceedings of the IEEE Conference on Computer Vision 98 | and Pattern Recognition. 99 | [34] Ting-Chun Wang, Arun Mallya, and Ming-Yu Liu. 2021. One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing. 2021 IEEE/CVF Conference 100 | on Computer Vision and Pattern Recognition (CVPR).MM ’22, October 10–14, 2022, Lisboa, Portugal Nikita Drobyshev et al. 101 | [35] Zhou Wang, Alan Conrad Bovik, Hamid R. Sheikh, and Eero P. Simoncelli. 2004. 102 | Image quality assessment: from error visibility to structural similarity. IEEE 103 | Transactions on Image Processing 13 (2004), 600–612. 104 | [36] Gengshan Yang, Minh Vo, Natalia Neverova, Deva Ramanan, Andrea Vedaldi, 105 | and Hanbyul Joo. 2021. 
BANMo: Building Animatable 3D Neural Models from 106 | Many Casual Videos. ArXiv. 107 | [37] Lingbo Yang, C. Liu, P. Wang, Shanshe Wang, P. Ren, Siwei Ma, and W. Gao. 2020. 108 | HiFaceGAN: Face Renovation via Collaborative Suppression and Replenishment. 109 | Proceedings of the 28th ACM International Conference on Multimedia (2020). 110 | [38] Egor Zakharov, Aleksei Ivakhnenko, Aliaksandra Shysheya, and Victor S. Lempitsky. 2020. Fast Bi-layer Neural Synthesis of One-Shot Realistic Head Avatars. 111 | In ECCV. 112 | [39] Egor Zakharov, Aliaksandra Shysheya, Egor Burkov, and Victor S. Lempitsky. 113 | 2019. Few-Shot Adversarial Learning of Realistic Neural Talking Head Models. 114 | 2019 IEEE/CVF International Conference on Computer Vision (ICCV). 115 | [40] Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. 116 | 2018. The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. 117 | 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition (2018), 118 | 586–595. 119 | [41] Shifeng Zhang, Xiangyu Zhu, Zhen Lei, Hailin Shi, Xiaobo Wang, and S. Li. 120 | 2017. S3FD: Single Shot Scale-Invariant Face Detector. 2017 IEEE International 121 | Conference on Computer Vision (ICCV) (2017), 192–201. 122 | [42] Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros. 2017. Unpaired 123 | Image-to-Image Translation Using Cycle-Consistent Adversarial Networks. 2017 124 | IEEE International Conference on Computer Vision (ICCV) (2017), 2242–2251. -------------------------------------------------------------------------------- /reference/resnetblocks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/resnetblocks.png -------------------------------------------------------------------------------- /reference/test.py: -------------------------------------------------------------------------------- 1 | from google_scholar_py import CustomGoogleScholarProfiles 2 | import json 3 | 4 | parser = CustomGoogleScholarProfiles() 5 | data = parser.scrape_google_scholar_profiles( 6 | query='blizzard', 7 | pagination=True, 8 | save_to_csv=False, 9 | save_to_json=True 10 | ) 11 | print(json.dumps(data, indent=2)) 12 | -------------------------------------------------------------------------------- /reference/warpfield.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/warpfield.png -------------------------------------------------------------------------------- /reference/warpgenerators/warpgenerators.txt: -------------------------------------------------------------------------------- 1 | Neural Head Reenactment with Latent Pose Descriptors (Egor Burkov et al., 2020) - Reference [3] 2 | HeadGAN: One-shot Neural Head Synthesis and Editing (Michail Christos Doukas et al., 2021) - Reference [8] 3 | Dynamic Neural Radiance Fields for Monocular 4D Facial Avatar Reconstruction (Guy Gafni et al., 2021) - Reference [10] 4 | MarioNETte: Few-shot Face Reenactment Preserving Identity of Unseen Targets (Sungjoo Ha et al., 2020) - Reference [12] 5 | Deep video portraits (Hyeongwoo Kim et al., 2018) - Reference [17] 6 | Face2Face: real-time face capture and reenactment of RGB videos (Justus Thies et al., 2019) - Reference [30] 7 | 
-------------------------------------------------------------------------------- /reference/x3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/x3.png -------------------------------------------------------------------------------- /reference/x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/MegaPortrait-hack/580cab3cadc1873dad5e7b9af0a0d9e882dee486/reference/x4.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # conda install pytorch3d -c pytorch3d 2 | 3 | pytorch-msssim 4 | chardet 5 | lpips 6 | mediapipe 7 | colored-traceback 8 | pygments 9 | git+https://github.com/johndpope/colored-traceback.py.git 10 | torchsummary 11 | moviepy 12 | decord 13 | # eva-decord # mac m1 support - is it broken? 14 | omegaconf 15 | memory_profiler 16 | face_recognition 17 | facenet_pytorch 18 | rembg 19 | lpips 20 | tensorboard 21 | torch-fidelity -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one 2 | # or more contributor license agreements. See the NOTICE file 3 | # distributed with this work for additional information 4 | # regarding copyright ownership. The ASF licenses this file 5 | # to you under the Apache License, Version 2.0 (the 6 | # "License"); you may not use this file except in compliance 7 | # with the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, 12 | # software distributed under the License is distributed on an 13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | # KIND, either express or implied. See the License for the 15 | # specific language governing permissions and limitations 16 | # under the License. 
17 | 18 | # build resnet for cifar10, debug use only 19 | # from https://github.com/huyvnphan/PyTorch_CIFAR10/blob/master/cifar10_models/resnet.py 20 | 21 | import os 22 | import requests 23 | from tqdm import tqdm 24 | import zipfile 25 | import torch.utils.model_zoo as modelzoo 26 | import torch.nn.functional as F 27 | import torch 28 | import torch.nn as nn 29 | 30 | __all__ = [ 31 | "ResNet", 32 | "resnet18", 33 | "resnet34", 34 | "resnet50", 35 | ] 36 | weights_downloaded = False 37 | 38 | 39 | 40 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 41 | """3x3 convolution with padding""" 42 | return nn.Conv2d( 43 | in_planes, 44 | out_planes, 45 | kernel_size=3, 46 | stride=stride, 47 | padding=dilation, 48 | groups=groups, 49 | bias=False, 50 | dilation=dilation, 51 | ) 52 | 53 | 54 | def conv1x1(in_planes, out_planes, stride=1): 55 | """1x1 convolution""" 56 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | expansion = 1 61 | 62 | def __init__( 63 | self, 64 | inplanes, 65 | planes, 66 | stride=1, 67 | downsample=None, 68 | groups=1, 69 | base_width=64, 70 | dilation=1, 71 | norm_layer=None, 72 | ): 73 | super(BasicBlock, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | if groups != 1 or base_width != 64: 77 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 78 | if dilation > 1: 79 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 80 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 81 | self.conv1 = conv3x3(inplanes, planes, stride) 82 | self.bn1 = norm_layer(planes) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.conv2 = conv3x3(planes, planes) 85 | self.bn2 = norm_layer(planes) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | 99 | if self.downsample is not None: 100 | identity = self.downsample(x) 101 | 102 | out += identity 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | class Bottleneck(nn.Module): 109 | expansion = 4 110 | 111 | def __init__( 112 | self, 113 | inplanes, 114 | planes, 115 | stride=1, 116 | downsample=None, 117 | groups=1, 118 | base_width=64, 119 | dilation=1, 120 | norm_layer=None, 121 | ): 122 | super(Bottleneck, self).__init__() 123 | if norm_layer is None: 124 | norm_layer = nn.BatchNorm2d 125 | width = int(planes * (base_width / 64.0)) * groups 126 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 127 | self.conv1 = conv1x1(inplanes, width) 128 | self.bn1 = norm_layer(width) 129 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 130 | self.bn2 = norm_layer(width) 131 | self.conv3 = conv1x1(width, planes * self.expansion) 132 | self.bn3 = norm_layer(planes * self.expansion) 133 | self.relu = nn.ReLU(inplace=True) 134 | self.downsample = downsample 135 | self.stride = stride 136 | 137 | def forward(self, x): 138 | identity = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | out = self.relu(out) 147 | 148 | out = self.conv3(out) 149 | out = self.bn3(out) 150 | 151 | if self.downsample is not None: 152 | identity = self.downsample(x) 153 | 154 | out += identity 155 | out = 
self.relu(out) 156 | 157 | return out 158 | 159 | 160 | class ResNet(nn.Module): 161 | def __init__( 162 | self, 163 | block, 164 | layers, 165 | num_classes=10, 166 | zero_init_residual=False, 167 | groups=1, 168 | width_per_group=64, 169 | replace_stride_with_dilation=None, 170 | norm_layer=None, 171 | ): 172 | super(ResNet, self).__init__() 173 | if norm_layer is None: 174 | norm_layer = nn.BatchNorm2d 175 | self._norm_layer = norm_layer 176 | 177 | self.inplanes = 64 178 | self.dilation = 1 179 | if replace_stride_with_dilation is None: 180 | # each element in the tuple indicates if we should replace 181 | # the 2x2 stride with a dilated convolution instead 182 | replace_stride_with_dilation = [False, False, False] 183 | if len(replace_stride_with_dilation) != 3: 184 | raise ValueError( 185 | "replace_stride_with_dilation should be None " 186 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 187 | ) 188 | self.groups = groups 189 | self.base_width = width_per_group 190 | 191 | # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 192 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 193 | # END 194 | 195 | self.bn1 = norm_layer(self.inplanes) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 198 | self.layer1 = self._make_layer(block, 64, layers[0]) 199 | self.layer2 = self._make_layer( 200 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 201 | ) 202 | self.layer3 = self._make_layer( 203 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 204 | ) 205 | self.layer4 = self._make_layer( 206 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 207 | ) 208 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 209 | self.fc = nn.Linear(512 * block.expansion, num_classes) 210 | 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv2d): 213 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 214 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 215 | nn.init.constant_(m.weight, 1) 216 | nn.init.constant_(m.bias, 0) 217 | 218 | # Zero-initialize the last BN in each residual branch, 219 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
220 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 221 | if zero_init_residual: 222 | for m in self.modules(): 223 | if isinstance(m, Bottleneck): 224 | nn.init.constant_(m.bn3.weight, 0) 225 | elif isinstance(m, BasicBlock): 226 | nn.init.constant_(m.bn2.weight, 0) 227 | 228 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 229 | norm_layer = self._norm_layer 230 | downsample = None 231 | previous_dilation = self.dilation 232 | if dilate: 233 | self.dilation *= stride 234 | stride = 1 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | downsample = nn.Sequential( 237 | conv1x1(self.inplanes, planes * block.expansion, stride), 238 | norm_layer(planes * block.expansion), 239 | ) 240 | 241 | layers = [] 242 | layers.append( 243 | block( 244 | self.inplanes, 245 | planes, 246 | stride, 247 | downsample, 248 | self.groups, 249 | self.base_width, 250 | previous_dilation, 251 | norm_layer, 252 | ) 253 | ) 254 | self.inplanes = planes * block.expansion 255 | for _ in range(1, blocks): 256 | layers.append( 257 | block( 258 | self.inplanes, 259 | planes, 260 | groups=self.groups, 261 | base_width=self.base_width, 262 | dilation=self.dilation, 263 | norm_layer=norm_layer, 264 | ) 265 | ) 266 | 267 | return nn.Sequential(*layers) 268 | 269 | def forward(self, x): 270 | x = self.conv1(x) 271 | x = self.bn1(x) 272 | x = self.relu(x) 273 | x = self.maxpool(x) 274 | 275 | x = self.layer1(x) 276 | x = self.layer2(x) 277 | x = self.layer3(x) 278 | x = self.layer4(x) 279 | 280 | x = self.avgpool(x) 281 | x = x.reshape(x.size(0), -1) 282 | x = self.fc(x) 283 | 284 | return x 285 | 286 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 287 | global weights_downloaded 288 | model = ResNet(block, layers, **kwargs) 289 | if pretrained: 290 | if not weights_downloaded: 291 | download_weights() 292 | weights_downloaded = True 293 | 294 | script_dir = os.path.dirname(__file__) 295 | state_dict_path = os.path.join(script_dir, "cifar10_models/state_dicts", arch + ".pt") 296 | if os.path.isfile(state_dict_path): 297 | state_dict = torch.load(state_dict_path, map_location=device) 298 | model.load_state_dict(state_dict) 299 | else: 300 | raise FileNotFoundError(f"No such file or directory: '{state_dict_path}'") 301 | return model 302 | 303 | 304 | def resnet18(pretrained=False, progress=True, device="cpu", **kwargs): 305 | """Constructs a ResNet-18 model. 306 | Args: 307 | pretrained (bool): If True, returns a model pre-trained on ImageNet 308 | progress (bool): If True, displays a progress bar of the download to stderr 309 | """ 310 | return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs) 311 | 312 | 313 | def resnet34(pretrained=False, progress=True, device="cpu", **kwargs): 314 | """Constructs a ResNet-34 model. 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | progress (bool): If True, displays a progress bar of the download to stderr 318 | """ 319 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs) 320 | 321 | 322 | def resnet50(pretrained=False, progress=True, device="cpu", **kwargs): 323 | """Constructs a ResNet-50 model. 
324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs) 329 | 330 | 331 | def download_weights(): 332 | 333 | script_dir = os.path.dirname(__file__) 334 | state_dicts_dir = os.path.join(script_dir, "cifar10_models") 335 | 336 | if os.path.isdir(state_dicts_dir) and len(os.listdir(state_dicts_dir)) > 0: 337 | print("Weights already downloaded. Skipping download.") 338 | return 339 | 340 | url = "https://rutgers.box.com/shared/static/gkw08ecs797j2et1ksmbg1w5t3idf5r5.zip" 341 | 342 | # Streaming, so we can iterate over the response. 343 | r = requests.get(url, stream=True) 344 | 345 | # Total size in Mebibyte 346 | total_size = int(r.headers.get("content-length", 0)) 347 | block_size = 2**20 # Mebibyte 348 | t = tqdm(total=total_size, unit="MiB", unit_scale=True) 349 | 350 | with open("state_dicts.zip", "wb") as f: 351 | for data in r.iter_content(block_size): 352 | t.update(len(data)) 353 | f.write(data) 354 | t.close() 355 | 356 | if total_size != 0 and t.n != total_size: 357 | raise Exception("Error, something went wrong") 358 | 359 | print("Download successful. Unzipping file...") 360 | path_to_zip_file = os.path.join(os.getcwd(), "state_dicts.zip") 361 | directory_to_extract_to = os.path.join(os.getcwd(), "cifar10_models") 362 | 363 | with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: 364 | zip_ref.extractall(directory_to_extract_to) 365 | print("Unzip file successful!") 366 | 367 | 368 | # original resblock 369 | class ResBlock2D(nn.Module): 370 | def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15): 371 | super(ResBlock2D, self).__init__() 372 | padding = self._get_same_padding(kernel, dilation) 373 | 374 | layer_s = list() 375 | layer_s.append(nn.Conv2d(n_c, n_c, kernel, padding=padding, dilation=dilation, bias=False)) 376 | layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) 377 | layer_s.append(nn.ELU(inplace=True)) 378 | # dropout 379 | layer_s.append(nn.Dropout(p_drop)) 380 | # convolution 381 | layer_s.append(nn.Conv2d(n_c, n_c, kernel, dilation=dilation, padding=padding, bias=False)) 382 | layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) 383 | self.layer = nn.Sequential(*layer_s) 384 | self.final_activation = nn.ELU(inplace=True) 385 | 386 | def _get_same_padding(self, kernel, dilation): 387 | return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2 388 | 389 | def forward(self, x): 390 | out = self.layer(x) 391 | return self.final_activation(x + out) 392 | 393 | 394 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 395 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 396 | for i in range(bnum-1): 397 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 398 | return nn.Sequential(*layers) 399 | 400 | 401 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 402 | class ResNet18(nn.Module): 403 | def __init__(self): 404 | super(ResNet18, self).__init__() 405 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 406 | bias=False) 407 | self.bn1 = nn.BatchNorm2d(64) 408 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 409 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 410 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 411 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 412 | self.layer4 = 
create_layer_basic(256, 512, bnum=2, stride=2) 413 | self.init_weight() 414 | 415 | def forward(self, x): 416 | x = self.conv1(x) 417 | x = F.relu(self.bn1(x)) 418 | x = self.maxpool(x) 419 | 420 | x = self.layer1(x) 421 | feat8 = self.layer2(x) # 1/8 422 | feat16 = self.layer3(feat8) # 1/16 423 | feat32 = self.layer4(feat16) # 1/32 424 | return feat8, feat16, feat32 425 | 426 | def init_weight(self): 427 | state_dict = modelzoo.load_url(resnet18_url) 428 | # state_dict = torch.load('/apdcephfs/share_1290939/kevinyxpang/STIT/resnet18-5c106cde.pth') 429 | self_state_dict = self.state_dict() 430 | for k, v in state_dict.items(): 431 | if 'fc' in k: continue 432 | self_state_dict.update({k: v}) 433 | self.load_state_dict(self_state_dict) 434 | 435 | def get_params(self): 436 | wd_params, nowd_params = [], [] 437 | for name, module in self.named_modules(): 438 | if isinstance(module, (nn.Linear, nn.Conv2d)): 439 | wd_params.append(module.weight) 440 | if not module.bias is None: 441 | nowd_params.append(module.bias) 442 | elif isinstance(module, nn.BatchNorm2d): 443 | nowd_params += list(module.parameters()) 444 | return wd_params, nowd_params 445 | 446 | 447 | 448 | -------------------------------------------------------------------------------- /resnet50.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | 4 | """ 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | from typing import Type, Any, Callable, Union, List, Optional 9 | import torchvision.models as models 10 | 11 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=dilation, groups=groups, bias=False, dilation=dilation) 15 | 16 | 17 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 18 | """1x1 convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion: int = 1 24 | 25 | def __init__( 26 | self, 27 | inplanes: int, 28 | planes: int, 29 | stride: int = 1, 30 | downsample: Optional[nn.Module] = None, 31 | groups: int = 1, 32 | base_width: int = 64, 33 | dilation: int = 1, 34 | norm_layer: Optional[Callable[..., nn.Module]] = None 35 | ) -> None: 36 | super(BasicBlock, self).__init__() 37 | if norm_layer is None: 38 | norm_layer = nn.BatchNorm2d 39 | if groups != 1 or base_width != 64: 40 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 41 | if dilation > 1: 42 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 43 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 44 | self.conv1 = conv3x3(inplanes, planes, stride) 45 | self.bn1 = norm_layer(planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = norm_layer(planes) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x: Tensor) -> Tensor: 53 | identity = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | identity = self.downsample(x) 64 | 65 | out += identity 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class 
Bottleneck(nn.Module): 72 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 73 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 74 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 75 | # This variant is also known as ResNet V1.5 and improves accuracy according to 76 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 77 | 78 | expansion: int = 4 79 | 80 | def __init__( 81 | self, 82 | inplanes: int, 83 | planes: int, 84 | stride: int = 1, 85 | downsample: Optional[nn.Module] = None, 86 | groups: int = 1, 87 | base_width: int = 64, 88 | dilation: int = 1, 89 | norm_layer: Optional[Callable[..., nn.Module]] = None 90 | ) -> None: 91 | super(Bottleneck, self).__init__() 92 | if norm_layer is None: 93 | norm_layer = nn.BatchNorm2d 94 | width = int(planes * (base_width / 64.)) * groups 95 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 96 | self.conv1 = conv1x1(inplanes, width) 97 | self.bn1 = norm_layer(width) 98 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 99 | self.bn2 = norm_layer(width) 100 | self.conv3 = conv1x1(width, planes * self.expansion) 101 | self.bn3 = norm_layer(planes * self.expansion) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x: Tensor) -> Tensor: 107 | identity = x 108 | 109 | out = self.conv1(x) 110 | out = self.bn1(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv3(out) 118 | out = self.bn3(out) 119 | 120 | if self.downsample is not None: 121 | identity = self.downsample(x) 122 | 123 | out += identity 124 | out = self.relu(out) 125 | 126 | return out 127 | 128 | 129 | class ResNet50(nn.Module): 130 | 131 | def __init__( 132 | self, 133 | block: Type[Union[BasicBlock, Bottleneck]] = Bottleneck, 134 | layers: List[int] = [3, 4, 6, 3], 135 | n_class: int = 1000, 136 | zero_init_residual: bool = False, 137 | groups: int = 1, 138 | width_per_group: int = 64, 139 | replace_stride_with_dilation: Optional[List[bool]] = None, 140 | norm_layer: Optional[Callable[..., nn.Module]] = None, 141 | is_remix=False 142 | ) -> None: 143 | super(ResNet50, self).__init__() 144 | if norm_layer is None: 145 | norm_layer = nn.BatchNorm2d 146 | self._norm_layer = norm_layer 147 | 148 | self.inplanes = 64 149 | self.dilation = 1 150 | if replace_stride_with_dilation is None: 151 | # each element in the tuple indicates if we should replace 152 | # the 2x2 stride with a dilated convolution instead 153 | replace_stride_with_dilation = [False, False, False] 154 | if len(replace_stride_with_dilation) != 3: 155 | raise ValueError("replace_stride_with_dilation should be None " 156 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 157 | self.groups = groups 158 | self.base_width = width_per_group 159 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 160 | bias=False) 161 | self.bn1 = norm_layer(self.inplanes) 162 | self.relu = nn.ReLU(inplace=True) 163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 164 | self.layer1 = self._make_layer(block, 64, layers[0]) 165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 166 | dilate=replace_stride_with_dilation[0]) 167 | self.layer3 = self._make_layer(block, 256, layers[2], 
stride=2, 168 | dilate=replace_stride_with_dilation[1]) 169 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 170 | dilate=replace_stride_with_dilation[2]) 171 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 172 | # self.fc = nn.Linear(512 * block.expansion, n_class) 173 | self.fc = nn.Linear(512 * block.expansion, 512) # Reduce to 512 dimensions 174 | 175 | # rot_classifier for Remix Match 176 | self.is_remix = is_remix 177 | if is_remix: 178 | self.rot_classifier = nn.Linear(2048, 4) 179 | 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d): 182 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 183 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 184 | nn.init.constant_(m.weight, 1) 185 | nn.init.constant_(m.bias, 0) 186 | 187 | # Zero-initialize the last BN in each residual branch, 188 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 189 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 190 | if zero_init_residual: 191 | for m in self.modules(): 192 | if isinstance(m, Bottleneck): 193 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 194 | elif isinstance(m, BasicBlock): 195 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 196 | 197 | self._initialize_weights() 198 | 199 | def _initialize_weights(self): 200 | pretrained_resnet50 = models.resnet50(pretrained=True) 201 | 202 | self.conv1.weight.data = pretrained_resnet50.conv1.weight.data.clone() 203 | self.bn1.weight.data = pretrained_resnet50.bn1.weight.data.clone() 204 | self.bn1.bias.data = pretrained_resnet50.bn1.bias.data.clone() 205 | 206 | for i in range(1, 5): 207 | layer = getattr(self, f'layer{i}') 208 | pretrained_layer = getattr(pretrained_resnet50, f'layer{i}') 209 | self._initialize_layer_weights(layer, pretrained_layer) 210 | 211 | # Comment out the following lines if you don't want to copy the FC layer weights 212 | # self.fc.weight.data = pretrained_resnet50.fc.weight.data.clone() 213 | # self.fc.bias.data = pretrained_resnet50.fc.bias.data.clone() 214 | 215 | def _initialize_layer_weights(self, layer, pretrained_layer): 216 | for block, pretrained_block in zip(layer, pretrained_layer): 217 | block.conv1.weight.data = pretrained_block.conv1.weight.data.clone() 218 | block.bn1.weight.data = pretrained_block.bn1.weight.data.clone() 219 | block.bn1.bias.data = pretrained_block.bn1.bias.data.clone() 220 | block.conv2.weight.data = pretrained_block.conv2.weight.data.clone() 221 | block.bn2.weight.data = pretrained_block.bn2.weight.data.clone() 222 | block.bn2.bias.data = pretrained_block.bn2.bias.data.clone() 223 | if isinstance(block, Bottleneck): 224 | block.conv3.weight.data = pretrained_block.conv3.weight.data.clone() 225 | block.bn3.weight.data = pretrained_block.bn3.weight.data.clone() 226 | block.bn3.bias.data = pretrained_block.bn3.bias.data.clone() 227 | if block.downsample is not None: 228 | block.downsample[0].weight.data = pretrained_block.downsample[0].weight.data.clone() 229 | block.downsample[1].weight.data = pretrained_block.downsample[1].weight.data.clone() 230 | block.downsample[1].bias.data = pretrained_block.downsample[1].bias.data.clone() 231 | 232 | 233 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 234 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 235 | norm_layer = self._norm_layer 236 | downsample = None 237 | previous_dilation = self.dilation 238 | if dilate: 239 | 
self.dilation *= stride 240 | stride = 1 241 | if stride != 1 or self.inplanes != planes * block.expansion: 242 | downsample = nn.Sequential( 243 | conv1x1(self.inplanes, planes * block.expansion, stride), 244 | norm_layer(planes * block.expansion), 245 | ) 246 | 247 | layers = [] 248 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 249 | self.base_width, previous_dilation, norm_layer)) 250 | self.inplanes = planes * block.expansion 251 | for _ in range(1, blocks): 252 | layers.append(block(self.inplanes, planes, groups=self.groups, 253 | base_width=self.base_width, dilation=self.dilation, 254 | norm_layer=norm_layer)) 255 | 256 | return nn.Sequential(*layers) 257 | 258 | def _forward_impl(self, x): 259 | # See note [TorchScript super()] 260 | x = self.conv1(x) 261 | x = self.bn1(x) 262 | x = self.relu(x) 263 | x = self.maxpool(x) 264 | 265 | x = self.layer1(x) 266 | x = self.layer2(x) 267 | x = self.layer3(x) 268 | x = self.layer4(x) 269 | 270 | x = self.avgpool(x) 271 | x = torch.flatten(x, 1) 272 | x = self.fc(x) # Reduce to 512 dimensions 273 | # out = self.fc(x) # Comment out this line if you don't want to use the FC layer 274 | if self.is_remix: 275 | rot_output = self.rot_classifier(x) 276 | return x, rot_output 277 | else: 278 | return x 279 | 280 | def forward(self, x): 281 | return self._forward_impl(x) 282 | 283 | 284 | class build_ResNet50: 285 | def __init__(self, is_remix=False): 286 | self.is_remix = is_remix 287 | 288 | def build(self, num_classes): 289 | return ResNet50(n_class=num_classes, is_remix=self.is_remix) 290 | 291 | 292 | if __name__ == '__main__': 293 | a = torch.rand(16, 3, 224, 224) 294 | net = ResNet50(is_remix=True) 295 | x,y = net(a) 296 | print(x.shape) 297 | print(y.shape) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | print(torch.backends.cudnn.is_available()) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import model 4 | import cv2 as cv 5 | import numpy as np 6 | import torch.nn as nn 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torch.autograd import Variable 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | from EmoDataset import EMODataset 13 | import torch.nn.functional as F 14 | from omegaconf import OmegaConf 15 | from torchvision import models 16 | from model import PerceptualLoss,IdentitySimilarityLoss, PairwiseTransferLoss,crop_and_warp_face, get_foreground_mask,remove_background_and_convert_to_rgb,apply_warping_field 17 | import mediapipe as mp 18 | import torchvision.transforms as transforms 19 | import os 20 | import torchvision.utils as vutils 21 | import time 22 | from torch.cuda.amp import autocast, GradScaler 23 | from torch.autograd import Variable 24 | 25 | from scipy.linalg import sqrtm 26 | from sklearn.metrics.pairwise import cosine_similarity 27 | from lpips import LPIPS 28 | 29 | from torch.utils.tensorboard import SummaryWriter 30 | 31 | output_dir = "output_images" 32 | os.makedirs(output_dir, exist_ok=True) 33 | 34 | face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5) 35 | 36 | use_cuda = torch.cuda.is_available() 37 | device = 
torch.device("cuda" if use_cuda else "cpu") 38 | 39 | 40 | 41 | 42 | # Function to calculate FID 43 | def calculate_fid(real_images, fake_images): 44 | real_images = real_images.detach().cpu().numpy() 45 | fake_images = fake_images.detach().cpu().numpy() 46 | mu1, sigma1 = real_images.mean(axis=0), np.cov(real_images, rowvar=False) 47 | mu2, sigma2 = fake_images.mean(axis=0), np.cov(fake_images, rowvar=False) 48 | ssdiff = np.sum((mu1 - mu2) ** 2.0) 49 | covmean = sqrtm(sigma1.dot(sigma2)) 50 | if np.iscomplexobj(covmean): 51 | covmean = covmean.real 52 | fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) 53 | return fid 54 | 55 | # Function to calculate CSIM (Cosine Similarity) 56 | def calculate_csim(real_features, fake_features): 57 | csim = cosine_similarity(real_features.detach().cpu().numpy(), fake_features.detach().cpu().numpy()) 58 | return np.mean(csim) 59 | 60 | # Function to calculate LPIPS 61 | def calculate_lpips(real_images, fake_images): 62 | lpips_model = LPIPS(net='alex').cuda() # 'alex', 'vgg', 'squeeze' 63 | lpips_scores = [] 64 | for real, fake in zip(real_images, fake_images): 65 | real = real.unsqueeze(0).cuda() 66 | fake = fake.unsqueeze(0).cuda() 67 | lpips_score = lpips_model(real, fake) 68 | lpips_scores.append(lpips_score.item()) 69 | return np.mean(lpips_scores) 70 | 71 | # align to cyclegan 72 | def discriminator_loss(real_pred, fake_pred, loss_type='lsgan'): 73 | if loss_type == 'lsgan': 74 | real_loss = torch.mean((real_pred - 1)**2) 75 | fake_loss = torch.mean(fake_pred**2) 76 | elif loss_type == 'vanilla': 77 | real_loss = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred)) 78 | fake_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred)) 79 | else: 80 | raise NotImplementedError(f'Loss type {loss_type} is not implemented.') 81 | 82 | return ((real_loss + fake_loss) * 0.5).requires_grad_() 83 | 84 | 85 | def cosine_loss(positive_pairs, negative_pairs, margin=0.5, scale=5): 86 | """ 87 | Calculates the cosine loss for the positive and negative pairs. 88 | 89 | Args: 90 | positive_pairs (list): List of tuples containing positive pairs (z_i, z_j). 91 | negative_pairs (list): List of tuples containing negative pairs (z_i, z_j). 92 | margin (float): Margin value for the cosine distance (default: 0.5). 93 | scale (float): Scaling factor for the cosine distance (default: 5). 94 | 95 | Returns: 96 | torch.Tensor: Cosine loss value. 
97 | """ 98 | def cosine_distance(z_i, z_j): 99 | # Normalize the feature vectors 100 | z_i = F.normalize(z_i, dim=-1) 101 | z_j = F.normalize(z_j, dim=-1) 102 | 103 | # Calculate the cosine similarity 104 | cos_sim = torch.sum(z_i * z_j, dim=-1) 105 | 106 | # Apply the scaling and margin 107 | cos_dist = scale * (cos_sim - margin) 108 | 109 | return cos_dist 110 | 111 | # Calculate the cosine distance for positive pairs 112 | pos_cos_dist = [cosine_distance(z_i, z_j) for z_i, z_j in positive_pairs] 113 | pos_cos_dist = torch.stack(pos_cos_dist) 114 | 115 | # Calculate the cosine distance for negative pairs 116 | neg_cos_dist = [cosine_distance(z_i, z_j) for z_i, z_j in negative_pairs] 117 | neg_cos_dist = torch.stack(neg_cos_dist) 118 | 119 | # Calculate the cosine loss 120 | loss = -torch.log(torch.exp(pos_cos_dist) / (torch.exp(pos_cos_dist) + torch.sum(torch.exp(neg_cos_dist)))) 121 | 122 | return loss.mean().requires_grad_() 123 | 124 | 125 | 126 | 127 | 128 | 129 | def train_base(cfg, Gbase, Dbase, dataloader, start_epoch=0): 130 | patch = (1, cfg.data.train_width // 2 ** 4, cfg.data.train_height // 2 ** 4) 131 | hinge_loss = nn.HingeEmbeddingLoss(reduction='mean') 132 | feature_matching_loss = nn.MSELoss() 133 | Gbase.train() 134 | Dbase.train() 135 | optimizer_G = torch.optim.AdamW(Gbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 136 | optimizer_D = torch.optim.AdamW(Dbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 137 | scheduler_G = CosineAnnealingLR(optimizer_G, T_max=cfg.training.base_epochs, eta_min=1e-6) 138 | scheduler_D = CosineAnnealingLR(optimizer_D, T_max=cfg.training.base_epochs, eta_min=1e-6) 139 | 140 | perceptual_loss_fn = PerceptualLoss(device, weights={'vgg19': 20.0, 'vggface': 4.0, 'gaze': 5.0,'lpips':10.0}) 141 | pairwise_transfer_loss = PairwiseTransferLoss() 142 | # identity_similarity_loss = IdentitySimilarityLoss() 143 | identity_similarity_loss = PerceptualLoss(device, weights={'vgg19': 0.0, 'vggface': 1.0, 'gaze': 0.0,'lpips':0.0}) # focus on face 144 | 145 | scaler = GradScaler() 146 | writer = SummaryWriter(log_dir='runs/training_logs') 147 | 148 | for epoch in range(start_epoch, cfg.training.base_epochs): 149 | print("Epoch:", epoch) 150 | 151 | epoch_loss_G = 0 152 | epoch_loss_D = 0 153 | 154 | fid_score = 0 155 | csim_score = 0 156 | lpips_score = 0 157 | 158 | 159 | 160 | for batch in dataloader: 161 | 162 | source_frames = batch['source_frames'] 163 | driving_frames = batch['driving_frames'] 164 | video_id = batch['video_id'][0] 165 | 166 | # Access videos from dataloader2 for cycle consistency 167 | source_frames2 = batch['source_frames_star'] 168 | driving_frames2 = batch['driving_frames_star'] 169 | video_id2 = batch['video_id_star'][0] 170 | 171 | 172 | num_frames = len(driving_frames) 173 | len_source_frames = len(source_frames) 174 | len_driving_frames = len(driving_frames) 175 | len_source_frames2 = len(source_frames2) 176 | len_driving_frames2 = len(driving_frames2) 177 | 178 | 179 | for idx in range(num_frames): 180 | # loop around if idx exceeds video length 181 | source_frame = source_frames[idx % len_source_frames].to(device) 182 | driving_frame = driving_frames[idx % len_driving_frames].to(device) 183 | 184 | source_frame_star = source_frames2[idx % len_source_frames2].to(device) 185 | driving_frame_star = driving_frames2[idx % len_driving_frames2].to(device) 186 | 187 | 188 | with autocast(): 189 | 190 | # We use multiple loss functions for training, which can be split into 
two groups. 191 | # The first group consists of the standard training objectives for image synthesis. 192 | # These include perceptual [14] and GAN [ 33 ] losses that match 193 | # the predicted image ˆx𝑠→𝑑 to the ground-truth x𝑑 . 194 | pred_frame,pred_pyramids = Gbase(source_frame, driving_frame) 195 | 196 | # Obtain the foreground mask for the driving image 197 | # foreground_mask = get_foreground_mask(source_frame) 198 | 199 | # # Move the foreground mask to the same device as output_frame 200 | # foreground_mask = foreground_mask.to(pred_frame.device) 201 | 202 | # # Multiply the predicted and driving images with the foreground mask 203 | # # masked_predicted_image = pred_frame * foreground_mask 204 | # masked_target_image = driving_frame * foreground_mask 205 | 206 | save_images = True 207 | # Save the images 208 | if save_images: 209 | # vutils.save_image(source_frame, f"{output_dir}/source_frame_{idx}.png") 210 | # vutils.save_image(driving_frame, f"{output_dir}/driving_frame_{idx}.png") 211 | vutils.save_image(pred_frame, f"{output_dir}/pred_frame_{idx}.png") 212 | # vutils.save_image(source_frame_star, f"{output_dir}/source_frame_star_{idx}.png") 213 | # vutils.save_image(driving_frame_star, f"{output_dir}/driving_frame_star_{idx}.png") 214 | # vutils.save_image(masked_predicted_image, f"{output_dir}/masked_predicted_image_{idx}.png") 215 | # vutils.save_image(masked_target_image, f"{output_dir}/masked_target_image_{idx}.png") 216 | 217 | # Calculate perceptual losses - use pyramid 218 | # loss_G_per = perceptual_loss_fn(pred_frame, source_frame) 219 | 220 | loss_G_per = 0 221 | for scale, pred_scaled in pred_pyramids.items(): 222 | target_scaled = F.interpolate(driving_frame, size=pred_scaled.shape[2:], mode='bilinear', align_corners=False) 223 | loss_G_per += perceptual_loss_fn(pred_scaled, target_scaled) 224 | 225 | # Adversarial ground truths - from Kevin Fringe 226 | valid = Variable(torch.Tensor(np.ones((driving_frame.size(0), *patch))), requires_grad=False).to(device) 227 | fake = Variable(torch.Tensor(-1 * np.ones((driving_frame.size(0), *patch))), requires_grad=False).to(device) 228 | 229 | # real loss 230 | real_pred = Dbase(driving_frame, source_frame) 231 | loss_real = hinge_loss(real_pred, valid) 232 | 233 | # fake loss 234 | fake_pred = Dbase(pred_frame.detach(), source_frame) 235 | loss_fake = hinge_loss(fake_pred, fake) 236 | 237 | # Train discriminator 238 | optimizer_D.zero_grad() 239 | 240 | # Calculate adversarial losses 241 | real_pred = Dbase(driving_frame, source_frame) 242 | fake_pred = Dbase(pred_frame.detach(), source_frame) 243 | loss_D = discriminator_loss(real_pred, fake_pred, loss_type='lsgan') 244 | 245 | scaler.scale(loss_D).backward() 246 | scaler.step(optimizer_D) 247 | scaler.update() 248 | 249 | # Calculate adversarial losses 250 | loss_G_adv = 0.5 * (loss_real + loss_fake) 251 | 252 | # Feature matching loss 253 | loss_fm = feature_matching_loss(pred_frame, driving_frame) 254 | writer.add_scalar('Loss/Feature Matching', loss_fm, epoch) 255 | 256 | 257 | 258 | 259 | # New disentangling losses - from VASA paper 260 | # I1 and I2 are from the same video, I3 and I4 are from different videos 261 | 262 | # Get the next frame index, wrapping around if necessary 263 | next_idx = (idx + 20) % len_source_frames 264 | 265 | I1 = source_frame 266 | I2 = source_frames[next_idx].to(device) 267 | I3 = source_frame_star 268 | I4 = source_frames2[next_idx % len_source_frames2].to(device) 269 | loss_pairwise = pairwise_transfer_loss(Gbase,I1, I2) 270 | 
loss_identity = identity_similarity_loss(I3, I4) 271 | 272 | 273 | writer.add_scalar('pairwise_transfer_loss', loss_pairwise, epoch) 274 | writer.add_scalar('identity_similarity_loss', loss_identity, epoch) 275 | 276 | 277 | # The other objective CycleGAN regularizes the training and introduces disentanglement between the motion and canonical space 278 | # In order to calculate this loss, we use an additional source-driving pair x𝑠∗ and x𝑑∗ , 279 | # which is sampled from a different video! and therefore has different appearance from the current x𝑠 , x𝑑 pair. 280 | 281 | # produce the following cross-reenacted image: ˆx𝑠∗→𝑑 = Gbase (x𝑠∗ , x𝑑 ) 282 | # 283 | cross_reenacted_image,_ = Gbase(source_frame_star, driving_frame) 284 | if save_images: 285 | vutils.save_image(cross_reenacted_image, f"{output_dir}/cross_reenacted_image_{idx}.png") 286 | 287 | # # Store the motion descriptors z𝑠→𝑑(predicted) and z𝑠∗→𝑑 (star predicted) from the 288 | # # respective forward passes of the base network. 289 | _, _, z_pred = Gbase.motionEncoder(pred_frame) 290 | _, _, zd = Gbase.motionEncoder(driving_frame) 291 | 292 | _, _, z_star__pred = Gbase.motionEncoder(cross_reenacted_image) 293 | _, _, zd_star = Gbase.motionEncoder(driving_frame_star) 294 | 295 | 296 | # # Calculate cycle consistency loss 297 | # # We then arrange the motion descriptors into positive pairs P that 298 | # # should align with each other: P = (z𝑠→𝑑 , z𝑑 ), (z𝑠∗→𝑑 , z𝑑 ) , and 299 | # # the negative pairs: N = (z𝑠→𝑑 , z𝑑∗ ), (z𝑠∗→𝑑 , z𝑑∗ ) . These pairs are 300 | # # used to calculate the following cosine distance: 301 | 302 | P = [(z_pred, zd) ,(z_star__pred, zd)] 303 | N = [(z_pred, zd_star),(z_star__pred, zd_star)] 304 | loss_G_cos = cosine_loss(P, N) 305 | 306 | 307 | writer.add_scalar('Cycle consistency loss', loss_G_cos, epoch) 308 | 309 | # Backpropagate and update generator 310 | optimizer_G.zero_grad() 311 | # Total generator loss 312 | total_loss = cfg.training.w_per * loss_G_per + \ 313 | cfg.training.w_adv * loss_G_adv + \ 314 | cfg.training.w_fm * loss_fm + \ 315 | cfg.training.w_cos * loss_G_cos + \ 316 | cfg.training.w_pairwise * loss_pairwise + \ 317 | cfg.training.w_identity * loss_identity 318 | scaler.scale(total_loss).backward() 319 | scaler.step(optimizer_G) 320 | scaler.update() 321 | 322 | 323 | epoch_loss_G += total_loss.item() 324 | epoch_loss_D += loss_D.item() 325 | 326 | 327 | 328 | 329 | 330 | avg_loss_G = epoch_loss_G / len(dataloader) 331 | avg_loss_D = epoch_loss_D / len(dataloader) 332 | 333 | writer.add_scalar('Loss/Generator', avg_loss_G, epoch) 334 | writer.add_scalar('Loss/Discriminator', avg_loss_D, epoch) 335 | 336 | 337 | writer.add_scalar('FID Score', fid_score, epoch) 338 | writer.add_scalar('CSIM Score', csim_score, epoch) 339 | writer.add_scalar('LPIPS Score', lpips_score, epoch) 340 | 341 | 342 | scheduler_G.step() 343 | scheduler_D.step() 344 | 345 | if (epoch + 1) % cfg.training.log_interval == 0: 346 | print(f"Epoch [{epoch+1}/{cfg.training.base_epochs}], " 347 | f"Loss_G: {loss_G_cos.item():.4f}, Loss_D: {loss_D.item():.4f}") 348 | 349 | if (epoch + 1) % cfg.training.save_interval == 0: 350 | torch.save({ 351 | 'epoch': epoch, 352 | 'model_G_state_dict': Gbase.state_dict(), 353 | 'model_D_state_dict': Dbase.state_dict(), 354 | 'optimizer_G_state_dict': optimizer_G.state_dict(), 355 | 'optimizer_D_state_dict': optimizer_D.state_dict(), 356 | }, f"checkpoint_epoch{epoch+1}.pth") 357 | 358 | # Calculate FID score for the current epoch 359 | # with torch.no_grad(): 360 | # real_images = 
torch.cat(real_images) 361 | # fake_images = torch.cat(fake_images) 362 | # fid_score = calculate_fid(real_images, fake_images) 363 | # csim_score = calculate_csim(real_images, fake_images) 364 | # lpips_score = calculate_lpips(real_images, fake_images) 365 | 366 | # writer.add_scalar('FID Score', fid_score, epoch) 367 | # writer.add_scalar('CSIM Score', csim_score, epoch) 368 | # writer.add_scalar('LPIPS Score', lpips_score, epoch) 369 | 370 | 371 | 372 | def load_checkpoint(checkpoint_path, model_G, model_D, optimizer_G, optimizer_D): 373 | if os.path.isfile(checkpoint_path): 374 | print(f"Loading checkpoint '{checkpoint_path}'") 375 | checkpoint = torch.load(checkpoint_path) 376 | model_G.load_state_dict(checkpoint['model_G_state_dict']) 377 | model_D.load_state_dict(checkpoint['model_D_state_dict']) 378 | optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict']) 379 | optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict']) 380 | start_epoch = checkpoint['epoch'] + 1 381 | print(f"Loaded checkpoint '{checkpoint_path}' (epoch {checkpoint['epoch']})") 382 | else: 383 | print(f"No checkpoint found at '{checkpoint_path}'") 384 | start_epoch = 0 385 | return start_epoch 386 | 387 | def main(cfg: OmegaConf) -> None: 388 | use_cuda = torch.cuda.is_available() 389 | device = torch.device("cuda" if use_cuda else "cpu") 390 | 391 | transform = transforms.Compose([ 392 | transforms.ToTensor(), 393 | transforms.RandomHorizontalFlip(), 394 | transforms.ColorJitter(), 395 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 396 | ]) 397 | 398 | dataset = EMODataset( 399 | use_gpu=use_cuda, 400 | remove_background=True, 401 | width=cfg.data.train_width, 402 | height=cfg.data.train_height, 403 | n_sample_frames=cfg.training.n_sample_frames, 404 | sample_rate=cfg.training.sample_rate, 405 | img_scale=(1.0, 1.0), 406 | video_dir=cfg.training.video_dir, 407 | json_file=cfg.training.json_file, 408 | transform=transform, 409 | apply_crop_warping=True 410 | ) 411 | 412 | 413 | 414 | dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0) 415 | 416 | 417 | Gbase = model.Gbase().to(device) 418 | Dbase = model.Discriminator().to(device) 419 | 420 | optimizer_G = torch.optim.AdamW(Gbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 421 | optimizer_D = torch.optim.AdamW(Dbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 422 | 423 | # Load checkpoint if available 424 | checkpoint_path = cfg.training.checkpoint_path 425 | start_epoch = load_checkpoint(checkpoint_path, Gbase, Dbase, optimizer_G, optimizer_D) 426 | 427 | 428 | train_base(cfg, Gbase, Dbase, dataloader, start_epoch) 429 | torch.save(Gbase.state_dict(), 'Gbase.pth') 430 | torch.save(Dbase.state_dict(), 'Dbase.pth') 431 | 432 | 433 | if __name__ == "__main__": 434 | config = OmegaConf.load("./configs/training/stage1-base.yaml") 435 | main(config) -------------------------------------------------------------------------------- /train_highres.py: -------------------------------------------------------------------------------- 1 | # not convinced we need to train this - see metaportrait Super Resolution model 2 | # https://github.com/Meta-Portrait/MetaPortrait/tree/main/sr_model 3 | import argparse 4 | import torch 5 | import model 6 | import cv2 as cv 7 | import numpy as np 8 | import torch.nn as nn 9 | from PIL import Image 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | from torch.autograd 
import Variable 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | from EmoDataset import EMODataset 15 | import torch.nn.functional as F 16 | import decord 17 | from omegaconf import OmegaConf 18 | from torchvision import models 19 | from model import MPGazeLoss,Encoder 20 | from rome_losses import Vgg19 # use vgg19 for perceptualloss 21 | import cv2 22 | import mediapipe as mp 23 | from memory_profiler import profile 24 | import torchvision.transforms as transforms 25 | import os 26 | from torchvision.utils import save_image 27 | 28 | 29 | 30 | # Create a directory to save the images (if it doesn't already exist) 31 | output_dir = "output_images" 32 | os.makedirs(output_dir, exist_ok=True) 33 | 34 | 35 | face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5) 36 | 37 | use_cuda = torch.cuda.is_available() 38 | device = torch.device("cuda" if use_cuda else "cpu") 39 | # torch.autograd.set_detect_anomaly(True)# this slows thing down - only for debug 40 | 41 | 42 | 43 | ''' 44 | We load the pre-trained DeepLabV3 model using models.segmentation.deeplabv3_resnet101(pretrained=True). This model is based on the ResNet-101 backbone and is pre-trained on the COCO dataset. 45 | We define the necessary image transformations using transforms.Compose. The transformations include converting the image to a tensor and normalizing it using the mean and standard deviation values specific to the model. 46 | We apply the transformations to the input image using transform(image) and add an extra dimension to represent the batch size using unsqueeze(0). 47 | We move the input tensor to the same device as the model to ensure compatibility. 48 | We perform the segmentation by passing the input tensor through the model using model(input_tensor). The output is a dictionary containing the segmentation map. 49 | We obtain the predicted segmentation mask by taking the argmax of the output along the channel dimension using torch.max(output['out'], dim=1). 50 | We convert the segmentation mask to a binary foreground mask by comparing the predicted class labels with the class index representing the person class (assuming it is 15 in this example). The resulting mask will have values of 1 for foreground pixels and 0 for background pixels. 51 | Finally, we return the foreground mask. 52 | ''' 53 | 54 | def get_foreground_mask(image): 55 | # Load the pre-trained DeepLabV3 model 56 | model = models.segmentation.deeplabv3_resnet101(pretrained=True) 57 | model.eval() 58 | 59 | # Define the image transformations 60 | transform = transforms.Compose([ 61 | transforms.ToTensor(), 62 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 63 | ]) 64 | 65 | # Apply the transformations to the input image 66 | input_tensor = transform(image).unsqueeze(0) 67 | 68 | # Move the input tensor to the same device as the model 69 | device = next(model.parameters()).device 70 | input_tensor = input_tensor.to(device) 71 | 72 | # Perform the segmentation 73 | with torch.no_grad(): 74 | output = model(input_tensor) 75 | 76 | # Get the predicted segmentation mask 77 | _, mask = torch.max(output['out'], dim=1) 78 | 79 | # Convert the segmentation mask to a binary foreground mask 80 | foreground_mask = (mask == 15).float() # Assuming class 15 represents the person class 81 | 82 | return foreground_mask 83 | 84 | 85 | ''' 86 | Perceptual Loss: 87 | 88 | The PerceptualLoss class combines losses from VGG19, VGG Face, and a specialized gaze loss. 
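(As a rough sketch: L_per ≈ w_vgg19*L_vgg19 + w_vggface*L_vggface + w_gaze*L_gaze, with the weights supplied when PerceptualLoss is constructed.)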
89 | It computes the perceptual losses by passing the output and target frames through the respective models and calculating the MSE loss between the features. 90 | The total perceptual loss is a weighted sum of the individual losses. 91 | 92 | 93 | Adversarial Loss: 94 | 95 | The adversarial_loss function computes the adversarial loss for the generator. 96 | It passes the generated output frame through the discriminator and calculates the MSE loss between the predicted values and a tensor of ones (indicating real samples). 97 | 98 | 99 | Cycle Consistency Loss: 100 | 101 | The cycle_consistency_loss function computes the cycle consistency loss. 102 | It passes the output frame and the source frame through the generator to reconstruct the source frame. 103 | The L1 loss is calculated between the reconstructed source frame and the original source frame. 104 | 105 | 106 | Contrastive Loss: 107 | 108 | The contrastive_loss function computes the contrastive loss using cosine similarity. 109 | It calculates the cosine similarity between positive pairs (output-source, output-driving) and negative pairs (output-random, source-random). 110 | The loss is computed as the negative log likelihood of the positive pairs over the sum of positive and negative pair similarities. 111 | The neg_pair_loss function calculates the loss for negative pairs using a margin. 112 | 113 | 114 | Discriminator Loss: 115 | 116 | The discriminator_loss function computes the loss for the discriminator. 117 | It calculates the MSE loss between the predicted values for real samples and a tensor of ones, and the MSE loss between the predicted values for fake samples and a tensor of zeros. 118 | The total discriminator loss is the sum of the real and fake losses. 119 | ''' 120 | 121 | # @profile 122 | def adversarial_loss(output_frame, discriminator): 123 | fake_pred = discriminator(output_frame) 124 | loss = F.mse_loss(fake_pred, torch.ones_like(fake_pred)) 125 | return loss.requires_grad_() 126 | 127 | # @profile 128 | def cycle_consistency_loss(output_frame, source_frame, driving_frame, generator): 129 | reconstructed_source = generator(output_frame, source_frame) 130 | loss = F.l1_loss(reconstructed_source, source_frame) 131 | return loss.requires_grad_() 132 | 133 | 134 | def contrastive_loss(output_frame, source_frame, driving_frame, encoder, margin=1.0): 135 | z_out = encoder(output_frame) 136 | z_src = encoder(source_frame) 137 | z_drv = encoder(driving_frame) 138 | z_rand = torch.randn_like(z_out, requires_grad=True) 139 | 140 | pos_pairs = [(z_out, z_src), (z_out, z_drv)] 141 | neg_pairs = [(z_out, z_rand), (z_src, z_rand)] 142 | 143 | loss = torch.tensor(0.0, requires_grad=True).to(device) 144 | for pos_pair in pos_pairs: 145 | loss = loss + torch.log(torch.exp(F.cosine_similarity(pos_pair[0], pos_pair[1])) / 146 | (torch.exp(F.cosine_similarity(pos_pair[0], pos_pair[1])) + 147 | neg_pair_loss(pos_pair, neg_pairs, margin))) 148 | 149 | return loss 150 | 151 | def neg_pair_loss(pos_pair, neg_pairs, margin): 152 | loss = torch.tensor(0.0, requires_grad=True).to(device) 153 | for neg_pair in neg_pairs: 154 | loss = loss + torch.exp(F.cosine_similarity(pos_pair[0], neg_pair[1]) - margin) 155 | return loss 156 | # @profile 157 | def discriminator_loss(real_pred, fake_pred): 158 | real_loss = F.mse_loss(real_pred, torch.ones_like(real_pred)) 159 | fake_loss = F.mse_loss(fake_pred, torch.zeros_like(fake_pred)) 160 | return (real_loss + fake_loss).requires_grad_() 161 | 162 | 163 | # @profile 164 | def 
gaze_loss_fn(predicted_gaze, target_gaze, face_image): 165 | # Ensure face_image has shape (C, H, W) 166 | if face_image.dim() == 4 and face_image.shape[0] == 1: 167 | face_image = face_image.squeeze(0) 168 | if face_image.dim() != 3 or face_image.shape[0] not in [1, 3]: 169 | raise ValueError(f"Expected face_image of shape (C, H, W), got {face_image.shape}") 170 | 171 | # Convert face image from tensor to numpy array 172 | face_image = face_image.detach().cpu().numpy() 173 | if face_image.shape[0] == 3: # if channels are first 174 | face_image = face_image.transpose(1, 2, 0) 175 | face_image = (face_image * 255).astype(np.uint8) 176 | 177 | # Extract eye landmarks using MediaPipe 178 | results = face_mesh.process(cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR)) 179 | if not results.multi_face_landmarks: 180 | return torch.tensor(0.0, requires_grad=True).to(device) 181 | 182 | eye_landmarks = [] 183 | for face_landmarks in results.multi_face_landmarks: 184 | left_eye_landmarks = [face_landmarks.landmark[idx] for idx in mp.solutions.face_mesh.FACEMESH_LEFT_EYE] 185 | right_eye_landmarks = [face_landmarks.landmark[idx] for idx in mp.solutions.face_mesh.FACEMESH_RIGHT_EYE] 186 | eye_landmarks.append((left_eye_landmarks, right_eye_landmarks)) 187 | 188 | # Compute loss for each eye 189 | loss = 0.0 190 | h, w = face_image.shape[:2] 191 | for left_eye, right_eye in eye_landmarks: 192 | # Convert landmarks to pixel coordinates 193 | left_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in left_eye] 194 | right_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in right_eye] 195 | 196 | # Create eye mask 197 | left_mask = torch.zeros((1, h, w), requires_grad=True).to(device) 198 | right_mask = torch.zeros((1, h, w), requires_grad=True).to(device) 199 | cv2.fillPoly(left_mask[0].cpu().numpy(), [np.array(left_eye_pixels)], 1.0) 200 | cv2.fillPoly(right_mask[0].cpu().numpy(), [np.array(right_eye_pixels)], 1.0) 201 | 202 | # Compute gaze loss for each eye 203 | left_gaze_loss = F.mse_loss(predicted_gaze * left_mask, target_gaze * left_mask) 204 | right_gaze_loss = F.mse_loss(predicted_gaze * right_mask, target_gaze * right_mask) 205 | loss += left_gaze_loss + right_gaze_loss 206 | 207 | return loss / len(eye_landmarks) 208 | 209 | 210 | def train_base(cfg, Gbase, Dbase, dataloader): 211 | Gbase.train() 212 | Dbase.train() 213 | optimizer_G = torch.optim.AdamW(Gbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 214 | optimizer_D = torch.optim.AdamW(Dbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 215 | scheduler_G = CosineAnnealingLR(optimizer_G, T_max=cfg.training.base_epochs, eta_min=1e-6) 216 | scheduler_D = CosineAnnealingLR(optimizer_D, T_max=cfg.training.base_epochs, eta_min=1e-6) 217 | 218 | vgg19 = Vgg19().to(device) 219 | perceptual_loss_fn = nn.L1Loss().to(device) 220 | # gaze_loss_fn = MPGazeLoss(device) 221 | encoder = Encoder(input_nc=3, output_nc=256).to(device) 222 | 223 | for epoch in range(cfg.training.base_epochs): 224 | print("epoch:", epoch) 225 | for batch in dataloader: 226 | source_frames = batch['source_frames'] #.to(device) 227 | driving_frames = batch['driving_frames'] #.to(device) 228 | 229 | num_frames = len(source_frames) # Get the number of frames in the batch 230 | 231 | for idx in range(num_frames): 232 | source_frame = source_frames[idx].to(device) 233 | driving_frame = driving_frames[idx].to(device) 234 | 235 | # Train generator 236 | optimizer_G.zero_grad() 237 | output_frame = Gbase(source_frame, 
driving_frame) 238 | 239 | # Resize output_frame to 256x256 to match the driving_frame size 240 | output_frame = F.interpolate(output_frame, size=(256, 256), mode='bilinear', align_corners=False) 241 | 242 | 243 | # 💀 Compute losses - "losses are calculated using ONLY foreground regions" 244 | # Obtain the foreground mask for the target image 245 | foreground_mask = get_foreground_mask(source_frame) 246 | 247 | # Multiply the predicted and target images with the foreground mask 248 | masked_predicted_image = output_frame * foreground_mask 249 | masked_target_image = source_frame * foreground_mask 250 | 251 | 252 | output_vgg_features = vgg19(masked_predicted_image) 253 | driving_vgg_features = vgg19(masked_target_image) 254 | total_loss = 0 255 | 256 | for output_feat, driving_feat in zip(output_vgg_features, driving_vgg_features): 257 | total_loss = total_loss + perceptual_loss_fn(output_feat, driving_feat.detach()) 258 | 259 | loss_adversarial = adversarial_loss(masked_predicted_image, Dbase) 260 | 261 | loss_gaze = gaze_loss_fn(output_frame, driving_frame, source_frame) # 🤷 fix this 262 | # Combine the losses and perform backpropagation and optimization 263 | total_loss = total_loss + loss_adversarial + loss_gaze 264 | 265 | 266 | # Accumulate gradients 267 | loss_gaze.backward() 268 | total_loss.backward(retain_graph=True) 269 | loss_adversarial.backward() 270 | 271 | # Update generator 272 | optimizer_G.step() 273 | 274 | # Train discriminator 275 | optimizer_D.zero_grad() 276 | real_pred = Dbase(driving_frame) 277 | fake_pred = Dbase(output_frame.detach()) 278 | loss_D = discriminator_loss(real_pred, fake_pred) 279 | 280 | # Backpropagate and update discriminator 281 | loss_D.backward() 282 | optimizer_D.step() 283 | 284 | 285 | # Update learning rates 286 | scheduler_G.step() 287 | scheduler_D.step() 288 | 289 | # Log and save checkpoints 290 | if (epoch + 1) % cfg.training.log_interval == 0: 291 | print(f"Epoch [{epoch+1}/{cfg.training.base_epochs}], " 292 | f"Loss_G: {loss_gaze.item():.4f}, Loss_D: {loss_D.item():.4f}") 293 | if (epoch + 1) % cfg.training.save_interval == 0: 294 | torch.save(Gbase.state_dict(), f"Gbase_epoch{epoch+1}.pth") 295 | torch.save(Dbase.state_dict(), f"Dbase_epoch{epoch+1}.pth") 296 | 297 | def train_hr(cfg, GHR, Genh, dataloader_hr): 298 | GHR.train() 299 | Genh.train() 300 | 301 | vgg19 = Vgg19().to(device) 302 | perceptual_loss_fn = nn.L1Loss().to(device) 303 | # gaze_loss_fn = MPGazeLoss(device=device) 304 | 305 | optimizer_G = torch.optim.AdamW(Genh.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 306 | scheduler_G = CosineAnnealingLR(optimizer_G, T_max=cfg.training.hr_epochs, eta_min=1e-6) 307 | 308 | for epoch in range(cfg.training.hr_epochs): 309 | for batch in dataloader_hr: 310 | source_frames = batch['source_frames'].to(device) 311 | driving_frames = batch['driving_frames'].to(device) 312 | 313 | num_frames = len(source_frames) # Get the number of frames in the batch 314 | 315 | for idx in range(num_frames): 316 | source_frame = source_frames[idx] 317 | driving_frame = driving_frames[idx] 318 | 319 | # Generate output frame using pre-trained base model 320 | with torch.no_grad(): 321 | xhat_base = GHR.Gbase(source_frame, driving_frame) 322 | 323 | # Train high-resolution model 324 | optimizer_G.zero_grad() 325 | xhat_hr = Genh(xhat_base) 326 | 327 | 328 | # Compute losses - option 1 329 | # loss_supervised = Genh.supervised_loss(xhat_hr, driving_frame) 330 | # loss_unsupervised = Genh.unsupervised_loss(xhat_base, 
xhat_hr) 331 | # loss_perceptual = perceptual_loss_fn(xhat_hr, driving_frame) 332 | 333 | # option2 ? 🤷 use vgg19 as per metaportrait? 334 | # - Compute losses 335 | xhat_hr_vgg_features = vgg19(xhat_hr) 336 | driving_vgg_features = vgg19(driving_frame) 337 | loss_perceptual = 0 338 | for xhat_hr_feat, driving_feat in zip(xhat_hr_vgg_features, driving_vgg_features): 339 | loss_perceptual += perceptual_loss_fn(xhat_hr_feat, driving_feat.detach()) 340 | 341 | loss_supervised = perceptual_loss_fn(xhat_hr, driving_frame) 342 | loss_unsupervised = perceptual_loss_fn(xhat_hr, xhat_base) 343 | loss_gaze = gaze_loss_fn(xhat_hr, driving_frame) 344 | loss_G = ( 345 | cfg.training.lambda_supervised * loss_supervised 346 | + cfg.training.lambda_unsupervised * loss_unsupervised 347 | + cfg.training.lambda_perceptual * loss_perceptual 348 | + cfg.training.lambda_gaze * loss_gaze 349 | ) 350 | 351 | # Backpropagate and update high-resolution model 352 | loss_G.backward() 353 | optimizer_G.step() 354 | 355 | # Update learning rate 356 | scheduler_G.step() 357 | 358 | # Log and save checkpoints 359 | if (epoch + 1) % cfg.training.log_interval == 0: 360 | print(f"Epoch [{epoch+1}/{cfg.training.hr_epochs}], " 361 | f"Loss_G: {loss_G.item():.4f}") 362 | if (epoch + 1) % cfg.training.save_interval == 0: 363 | torch.save(Genh.state_dict(), f"Genh_epoch{epoch+1}.pth") 364 | 365 | 366 | def train_student(cfg, Student, GHR, dataloader_avatars): 367 | Student.train() 368 | 369 | optimizer_S = torch.optim.AdamW(Student.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2) 370 | 371 | scheduler_S = CosineAnnealingLR(optimizer_S, T_max=cfg.training.student_epochs, eta_min=1e-6) 372 | 373 | for epoch in range(cfg.training.student_epochs): 374 | for batch in dataloader_avatars: 375 | avatar_indices = batch['avatar_indices'].to(device) 376 | driving_frames = batch['driving_frames'].to(device) 377 | 378 | # Generate high-resolution output frames using pre-trained HR model 379 | with torch.no_grad(): 380 | xhat_hr = GHR(driving_frames) 381 | 382 | # Train student model 383 | optimizer_S.zero_grad() 384 | 385 | # Generate output frames using student model 386 | xhat_student = Student(driving_frames, avatar_indices) 387 | 388 | # Compute loss 389 | loss_S = F.mse_loss(xhat_student, xhat_hr) 390 | 391 | # Backpropagate and update student model 392 | loss_S.backward() 393 | optimizer_S.step() 394 | 395 | # Update learning rate 396 | scheduler_S.step() 397 | 398 | # Log and save checkpoints 399 | if (epoch + 1) % cfg.training.log_interval == 0: 400 | print(f"Epoch [{epoch+1}/{cfg.training.student_epochs}], " 401 | f"Loss_S: {loss_S.item():.4f}") 402 | 403 | if (epoch + 1) % cfg.training.save_interval == 0: 404 | torch.save(Student.state_dict(), f"Student_epoch{epoch+1}.pth") 405 | 406 | def main(cfg: OmegaConf) -> None: 407 | use_cuda = torch.cuda.is_available() 408 | device = torch.device("cuda" if use_cuda else "cpu") 409 | 410 | transform = transforms.Compose([ 411 | transforms.ToTensor(), 412 | transforms.Normalize([0.5], [0.5]), 413 | transforms.RandomHorizontalFlip(), 414 | transforms.ColorJitter() # as augmentation for both source and target images, we use color jitter and random flip 415 | ]) 416 | 417 | dataset = EMODataset( 418 | use_gpu=use_cuda, 419 | width=cfg.data.train_width, 420 | height=cfg.data.train_height, 421 | n_sample_frames=cfg.training.n_sample_frames, 422 | sample_rate=cfg.training.sample_rate, 423 | img_scale=(1.0, 1.0), 424 | video_dir=cfg.training.video_dir, 425 | 
json_file=cfg.training.json_file, 426 | transform=transform 427 | ) 428 | 429 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4) 430 | 431 | Gbase = model.Gbase() 432 | Dbase = model.Discriminator(input_nc=3).to(device) # 🤷 433 | 434 | train_base(cfg, Gbase, Dbase, dataloader) 435 | 436 | GHR = model.GHR() 437 | GHR.Gbase.load_state_dict(Gbase.state_dict()) 438 | Dhr = model.Discriminator(input_nc=3).to(device) # 🤷 439 | train_hr(cfg, GHR, Dhr, dataloader) 440 | 441 | Student = model.Student(num_avatars=100) # this should equal the number of celebs in dataset 442 | train_student(cfg, Student, GHR, dataloader) 443 | 444 | torch.save(Gbase.state_dict(), 'Gbase.pth') 445 | torch.save(GHR.state_dict(), 'GHR.pth') 446 | torch.save(Student.state_dict(), 'Student.pth') 447 | 448 | if __name__ == "__main__": 449 | config = OmegaConf.load("./configs/training/stage1-base.yaml") 450 | main(config) -------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | # Todo 2 | import argparse 3 | import torch 4 | import model 5 | import cv2 as cv 6 | import numpy as np 7 | import torch.nn as nn 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torch.autograd import Variable 12 | from torch.optim.lr_scheduler import CosineAnnealingLR 13 | from EmoDataset import EMODataset 14 | import torch.nn.functional as F 15 | import decord 16 | from omegaconf import OmegaConf 17 | from torchvision import models 18 | from model import MPGazeLoss,Encoder 19 | from rome_losses import Vgg19 # use vgg19 for perceptualloss 20 | import cv2 21 | import mediapipe as mp 22 | from memory_profiler import profile 23 | import torchvision.transforms as transforms 24 | import os 25 | from torchvision.utils import save_image 26 | 27 | 28 | 29 | # Create a directory to save the images (if it doesn't already exist) 30 | output_dir = "output_images" 31 | os.makedirs(output_dir, exist_ok=True) 32 | 33 | 34 | face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5) 35 | 36 | use_cuda = torch.cuda.is_available() 37 | device = torch.device("cuda" if use_cuda else "cpu") 38 | # torch.autograd.set_detect_anomaly(True)# this slows thing down - only for debug 39 | 40 | 41 | 42 | ''' 43 | We load the pre-trained DeepLabV3 model using models.segmentation.deeplabv3_resnet101(pretrained=True). This model is based on the ResNet-101 backbone and is pre-trained on the COCO dataset. 44 | We define the necessary image transformations using transforms.Compose. The transformations include converting the image to a tensor and normalizing it using the mean and standard deviation values specific to the model. 45 | We apply the transformations to the input image using transform(image) and add an extra dimension to represent the batch size using unsqueeze(0). 46 | We move the input tensor to the same device as the model to ensure compatibility. 47 | We perform the segmentation by passing the input tensor through the model using model(input_tensor). The output is a dictionary containing the segmentation map. 48 | We obtain the predicted segmentation mask by taking the argmax of the output along the channel dimension using torch.max(output['out'], dim=1). 
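(The argmax yields an integer class-id map of shape (batch, H, W); in the Pascal-VOC-style label set used by this pretrained torchvision model, class 15 is 'person'.)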
49 | We convert the segmentation mask to a binary foreground mask by comparing the predicted class labels with the class index representing the person class (assuming it is 15 in this example). The resulting mask will have values of 1 for foreground pixels and 0 for background pixels. 50 | Finally, we return the foreground mask. 51 | ''' 52 | 53 | def get_foreground_mask(image): 54 | # Load the pre-trained DeepLabV3 model 55 | model = models.segmentation.deeplabv3_resnet101(pretrained=True) 56 | model.eval() 57 | 58 | # Define the image transformations 59 | transform = transforms.Compose([ 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 62 | ]) 63 | 64 | # Apply the transformations to the input image 65 | input_tensor = transform(image).unsqueeze(0) 66 | 67 | # Move the input tensor to the same device as the model 68 | device = next(model.parameters()).device 69 | input_tensor = input_tensor.to(device) 70 | 71 | # Perform the segmentation 72 | with torch.no_grad(): 73 | output = model(input_tensor) 74 | 75 | # Get the predicted segmentation mask 76 | _, mask = torch.max(output['out'], dim=1) 77 | 78 | # Convert the segmentation mask to a binary foreground mask 79 | foreground_mask = (mask == 15).float() # Assuming class 15 represents the person class 80 | 81 | return foreground_mask 82 | 83 | 84 | ''' 85 | Perceptual Loss: 86 | 87 | The PerceptualLoss class combines losses from VGG19, VGG Face, and a specialized gaze loss. 88 | It computes the perceptual losses by passing the output and target frames through the respective models and calculating the MSE loss between the features. 89 | The total perceptual loss is a weighted sum of the individual losses. 90 | 91 | 92 | Adversarial Loss: 93 | 94 | The adversarial_loss function computes the adversarial loss for the generator. 95 | It passes the generated output frame through the discriminator and calculates the MSE loss between the predicted values and a tensor of ones (indicating real samples). 96 | 97 | 98 | Cycle Consistency Loss: 99 | 100 | The cycle_consistency_loss function computes the cycle consistency loss. 101 | It passes the output frame and the source frame through the generator to reconstruct the source frame. 102 | The L1 loss is calculated between the reconstructed source frame and the original source frame. 103 | 104 | 105 | Contrastive Loss: 106 | 107 | The contrastive_loss function computes the contrastive loss using cosine similarity. 108 | It calculates the cosine similarity between positive pairs (output-source, output-driving) and negative pairs (output-random, source-random). 109 | The loss is computed as the negative log likelihood of the positive pairs over the sum of positive and negative pair similarities. 110 | The neg_pair_loss function calculates the loss for negative pairs using a margin. 111 | 112 | 113 | Discriminator Loss: 114 | 115 | The discriminator_loss function computes the loss for the discriminator. 116 | It calculates the MSE loss between the predicted values for real samples and a tensor of ones, and the MSE loss between the predicted values for fake samples and a tensor of zeros. 117 | The total discriminator loss is the sum of the real and fake losses. 
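Schematic summary (informal notation mirroring the descriptions above; x_hat is the
generated frame, x_s the source frame, D the discriminator):

    L_adv(G)  = MSE(D(x_hat), 1)
    L_cycle   = || G(x_hat, x_s) - x_s ||_1
    L_contr   = -sum_p log( exp(cos(p)) / (exp(cos(p)) + sum_n exp(cos(n) - margin)) )
    L_D       = MSE(D(x_real), 1) + MSE(D(x_fake), 0)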
118 | ''' 119 | 120 | # @profile 121 | def adversarial_loss(output_frame, discriminator): 122 | fake_pred = discriminator(output_frame) 123 | loss = F.mse_loss(fake_pred, torch.ones_like(fake_pred)) 124 | return loss.requires_grad_() 125 | 126 | # @profile 127 | def cycle_consistency_loss(output_frame, source_frame, driving_frame, generator): 128 | reconstructed_source = generator(output_frame, source_frame) 129 | loss = F.l1_loss(reconstructed_source, source_frame) 130 | return loss.requires_grad_() 131 | 132 | 133 | def contrastive_loss(output_frame, source_frame, driving_frame, encoder, margin=1.0): 134 | z_out = encoder(output_frame) 135 | z_src = encoder(source_frame) 136 | z_drv = encoder(driving_frame) 137 | z_rand = torch.randn_like(z_out, requires_grad=True) 138 | 139 | pos_pairs = [(z_out, z_src), (z_out, z_drv)] 140 | neg_pairs = [(z_out, z_rand), (z_src, z_rand)] 141 | 142 | loss = torch.tensor(0.0, requires_grad=True).to(device) 143 | for pos_pair in pos_pairs: 144 | loss = loss + torch.log(torch.exp(F.cosine_similarity(pos_pair[0], pos_pair[1])) / 145 | (torch.exp(F.cosine_similarity(pos_pair[0], pos_pair[1])) + 146 | neg_pair_loss(pos_pair, neg_pairs, margin))) 147 | 148 | return loss 149 | 150 | def neg_pair_loss(pos_pair, neg_pairs, margin): 151 | loss = torch.tensor(0.0, requires_grad=True).to(device) 152 | for neg_pair in neg_pairs: 153 | loss = loss + torch.exp(F.cosine_similarity(pos_pair[0], neg_pair[1]) - margin) 154 | return loss 155 | # @profile 156 | def discriminator_loss(real_pred, fake_pred): 157 | real_loss = F.mse_loss(real_pred, torch.ones_like(real_pred)) 158 | fake_loss = F.mse_loss(fake_pred, torch.zeros_like(fake_pred)) 159 | return (real_loss + fake_loss).requires_grad_() 160 | 161 | 162 | # @profile 163 | def gaze_loss_fn(predicted_gaze, target_gaze, face_image): 164 | # Ensure face_image has shape (C, H, W) 165 | if face_image.dim() == 4 and face_image.shape[0] == 1: 166 | face_image = face_image.squeeze(0) 167 | if face_image.dim() != 3 or face_image.shape[0] not in [1, 3]: 168 | raise ValueError(f"Expected face_image of shape (C, H, W), got {face_image.shape}") 169 | 170 | # Convert face image from tensor to numpy array 171 | face_image = face_image.detach().cpu().numpy() 172 | if face_image.shape[0] == 3: # if channels are first 173 | face_image = face_image.transpose(1, 2, 0) 174 | face_image = (face_image * 255).astype(np.uint8) 175 | 176 | # Extract eye landmarks using MediaPipe 177 | results = face_mesh.process(cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR)) 178 | if not results.multi_face_landmarks: 179 | return torch.tensor(0.0, requires_grad=True).to(device) 180 | 181 | eye_landmarks = [] 182 | for face_landmarks in results.multi_face_landmarks: 183 | left_eye_landmarks = [face_landmarks.landmark[idx] for idx in mp.solutions.face_mesh.FACEMESH_LEFT_EYE] 184 | right_eye_landmarks = [face_landmarks.landmark[idx] for idx in mp.solutions.face_mesh.FACEMESH_RIGHT_EYE] 185 | eye_landmarks.append((left_eye_landmarks, right_eye_landmarks)) 186 | 187 | # Compute loss for each eye 188 | loss = 0.0 189 | h, w = face_image.shape[:2] 190 | for left_eye, right_eye in eye_landmarks: 191 | # Convert landmarks to pixel coordinates 192 | left_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in left_eye] 193 | right_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in right_eye] 194 | 195 | # Create eye mask 196 | left_mask = torch.zeros((1, h, w), requires_grad=True).to(device) 197 | right_mask = torch.zeros((1, h, w), requires_grad=True).to(device) 198 | 
# @profile
def gaze_loss_fn(predicted_gaze, target_gaze, face_image):
    # Ensure face_image has shape (C, H, W)
    if face_image.dim() == 4 and face_image.shape[0] == 1:
        face_image = face_image.squeeze(0)
    if face_image.dim() != 3 or face_image.shape[0] not in [1, 3]:
        raise ValueError(f"Expected face_image of shape (C, H, W), got {face_image.shape}")

    # Convert face image from tensor to numpy array; frames are normalized to [-1, 1]
    # by the dataset transform, so map them back to [0, 255] for MediaPipe.
    face_image = face_image.detach().cpu().numpy()
    if face_image.shape[0] == 3:  # if channels are first
        face_image = face_image.transpose(1, 2, 0)
    face_image = np.clip((face_image * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

    # Extract eye landmarks using MediaPipe. FACEMESH_LEFT_EYE / FACEMESH_RIGHT_EYE are
    # sets of (start, end) connections, so flatten them into landmark indices first.
    left_eye_indices = {idx for connection in mp.solutions.face_mesh.FACEMESH_LEFT_EYE for idx in connection}
    right_eye_indices = {idx for connection in mp.solutions.face_mesh.FACEMESH_RIGHT_EYE for idx in connection}

    results = face_mesh.process(cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR))
    if not results.multi_face_landmarks:
        return torch.tensor(0.0, requires_grad=True).to(device)

    eye_landmarks = []
    for face_landmarks in results.multi_face_landmarks:
        left_eye_landmarks = [face_landmarks.landmark[idx] for idx in left_eye_indices]
        right_eye_landmarks = [face_landmarks.landmark[idx] for idx in right_eye_indices]
        eye_landmarks.append((left_eye_landmarks, right_eye_landmarks))

    # Compute loss for each eye
    loss = 0.0
    h, w = face_image.shape[:2]
    for left_eye, right_eye in eye_landmarks:
        # Convert landmarks to pixel coordinates
        left_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in left_eye]
        right_eye_pixels = [(int(lm.x * w), int(lm.y * h)) for lm in right_eye]

        # Rasterize the eye masks in numpy (cv2.fillPoly cannot write into a torch
        # tensor), then move them onto the training device.
        left_mask_np = np.zeros((h, w), dtype=np.float32)
        right_mask_np = np.zeros((h, w), dtype=np.float32)
        cv2.fillPoly(left_mask_np, [np.array(left_eye_pixels, dtype=np.int32)], 1.0)
        cv2.fillPoly(right_mask_np, [np.array(right_eye_pixels, dtype=np.int32)], 1.0)
        left_mask = torch.from_numpy(left_mask_np).unsqueeze(0).to(device)
        right_mask = torch.from_numpy(right_mask_np).unsqueeze(0).to(device)

        # Compute gaze loss for each eye
        left_gaze_loss = F.mse_loss(predicted_gaze * left_mask, target_gaze * left_mask)
        right_gaze_loss = F.mse_loss(predicted_gaze * right_mask, target_gaze * right_mask)
        loss += left_gaze_loss + right_gaze_loss

    return loss / len(eye_landmarks)


def train_base(cfg, Gbase, Dbase, dataloader):
    Gbase.train()
    Dbase.train()
    optimizer_G = torch.optim.AdamW(Gbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2)
    optimizer_D = torch.optim.AdamW(Dbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2)
    scheduler_G = CosineAnnealingLR(optimizer_G, T_max=cfg.training.base_epochs, eta_min=1e-6)
    scheduler_D = CosineAnnealingLR(optimizer_D, T_max=cfg.training.base_epochs, eta_min=1e-6)

    vgg19 = Vgg19().to(device)
    perceptual_loss_fn = nn.L1Loss().to(device)
    # gaze_loss_fn = MPGazeLoss(device)
    encoder = Encoder(input_nc=3, output_nc=256).to(device)

    for epoch in range(cfg.training.base_epochs):
        print("epoch:", epoch)
        for batch in dataloader:
            source_frames = batch['source_frames']  # .to(device)
            driving_frames = batch['driving_frames']  # .to(device)

            num_frames = len(source_frames)  # Get the number of frames in the batch

            for idx in range(num_frames):
                source_frame = source_frames[idx].to(device)
                driving_frame = driving_frames[idx].to(device)

                # Train generator
                optimizer_G.zero_grad()
                output_frame = Gbase(source_frame, driving_frame)

                # Resize output_frame to 256x256 to match the driving_frame size
                output_frame = F.interpolate(output_frame, size=(256, 256), mode='bilinear', align_corners=False)

                # 💀 Compute losses - "losses are calculated using ONLY foreground regions"
                # Obtain the foreground mask for the target image
                foreground_mask = get_foreground_mask(source_frame)

                # Multiply the predicted and target images with the foreground mask
                masked_predicted_image = output_frame * foreground_mask
                masked_target_image = source_frame * foreground_mask

                output_vgg_features = vgg19(masked_predicted_image)
                driving_vgg_features = vgg19(masked_target_image)
                total_loss = 0

                for output_feat, driving_feat in zip(output_vgg_features, driving_vgg_features):
                    total_loss = total_loss + perceptual_loss_fn(output_feat, driving_feat.detach())

                loss_adversarial = adversarial_loss(masked_predicted_image, Dbase)

                loss_gaze = gaze_loss_fn(output_frame, driving_frame, source_frame)  # 🤷 fix this
                # Combine the losses and perform backpropagation and optimization
                total_loss = total_loss + loss_adversarial + loss_gaze

                # Backpropagate once through the combined loss (it already contains the
                # perceptual, adversarial and gaze terms) and update the generator.
                total_loss.backward()
                optimizer_G.step()

                # Train discriminator
                optimizer_D.zero_grad()
                real_pred = Dbase(driving_frame)
                fake_pred = Dbase(output_frame.detach())
                loss_D = discriminator_loss(real_pred, fake_pred)

                # Backpropagate and update discriminator
                loss_D.backward()
                optimizer_D.step()

        # Update learning rates
        scheduler_G.step()
        scheduler_D.step()

        # Log and save checkpoints
        if (epoch + 1) % cfg.training.log_interval == 0:
            print(f"Epoch [{epoch+1}/{cfg.training.base_epochs}], "
                  f"Loss_G: {total_loss.item():.4f}, Loss_D: {loss_D.item():.4f}")  # report the combined generator loss
        if (epoch + 1) % cfg.training.save_interval == 0:
            torch.save(Gbase.state_dict(), f"Gbase_epoch{epoch+1}.pth")
            torch.save(Dbase.state_dict(), f"Dbase_epoch{epoch+1}.pth")


def train_hr(cfg, GHR, Genh, dataloader_hr):
    GHR.train()
    Genh.train()

    vgg19 = Vgg19().to(device)
    perceptual_loss_fn = nn.L1Loss().to(device)
    # gaze_loss_fn = MPGazeLoss(device=device)

    optimizer_G = torch.optim.AdamW(Genh.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2)
    scheduler_G = CosineAnnealingLR(optimizer_G, T_max=cfg.training.hr_epochs, eta_min=1e-6)

    for epoch in range(cfg.training.hr_epochs):
        for batch in dataloader_hr:
            source_frames = batch['source_frames'].to(device)
            driving_frames = batch['driving_frames'].to(device)

            num_frames = len(source_frames)  # Get the number of frames in the batch

            for idx in range(num_frames):
                source_frame = source_frames[idx]
                driving_frame = driving_frames[idx]

                # Generate output frame using pre-trained base model
                with torch.no_grad():
                    xhat_base = GHR.Gbase(source_frame, driving_frame)

                # Train high-resolution model
                optimizer_G.zero_grad()
                xhat_hr = Genh(xhat_base)

                # Compute losses - option 1
                # loss_supervised = Genh.supervised_loss(xhat_hr, driving_frame)
                # loss_unsupervised = Genh.unsupervised_loss(xhat_base, xhat_hr)
                # loss_perceptual = perceptual_loss_fn(xhat_hr, driving_frame)

                # option 2? 🤷 use vgg19 as per MetaPortrait?
                # - Compute losses
                xhat_hr_vgg_features = vgg19(xhat_hr)
                driving_vgg_features = vgg19(driving_frame)
                loss_perceptual = 0
                for xhat_hr_feat, driving_feat in zip(xhat_hr_vgg_features, driving_vgg_features):
                    loss_perceptual += perceptual_loss_fn(xhat_hr_feat, driving_feat.detach())

                loss_supervised = perceptual_loss_fn(xhat_hr, driving_frame)
                loss_unsupervised = perceptual_loss_fn(xhat_hr, xhat_base)
                # 🤷 same placeholder usage as in train_base: gaze_loss_fn needs a face
                # image as its third argument, so the driving frame is reused here
                loss_gaze = gaze_loss_fn(xhat_hr, driving_frame, driving_frame)
                loss_G = (
                    cfg.training.lambda_supervised * loss_supervised
                    + cfg.training.lambda_unsupervised * loss_unsupervised
                    + cfg.training.lambda_perceptual * loss_perceptual
                    + cfg.training.lambda_gaze * loss_gaze
                )

                # Backpropagate and update high-resolution model
                loss_G.backward()
                optimizer_G.step()

        # Update learning rate
        scheduler_G.step()

        # Log and save checkpoints
        if (epoch + 1) % cfg.training.log_interval == 0:
            print(f"Epoch [{epoch+1}/{cfg.training.hr_epochs}], "
                  f"Loss_G: {loss_G.item():.4f}")
        if (epoch + 1) % cfg.training.save_interval == 0:
            torch.save(Genh.state_dict(), f"Genh_epoch{epoch+1}.pth")


def train_student(cfg, Student, GHR, dataloader_avatars):
    Student.train()

    optimizer_S = torch.optim.AdamW(Student.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2)

    scheduler_S = CosineAnnealingLR(optimizer_S, T_max=cfg.training.student_epochs, eta_min=1e-6)

    for epoch in range(cfg.training.student_epochs):
        for batch in dataloader_avatars:
            avatar_indices = batch['avatar_indices'].to(device)
            driving_frames = batch['driving_frames'].to(device)

            # Generate high-resolution output frames using pre-trained HR model
            with torch.no_grad():
                xhat_hr = GHR(driving_frames)

            # Train student model
            optimizer_S.zero_grad()

            # Generate output frames using student model
            xhat_student = Student(driving_frames, avatar_indices)

            # Compute loss
            loss_S = F.mse_loss(xhat_student, xhat_hr)

            # Backpropagate and update student model
            loss_S.backward()
            optimizer_S.step()

        # Update learning rate
        scheduler_S.step()

        # Log and save checkpoints
        if (epoch + 1) % cfg.training.log_interval == 0:
            print(f"Epoch [{epoch+1}/{cfg.training.student_epochs}], "
                  f"Loss_S: {loss_S.item():.4f}")

        if (epoch + 1) % cfg.training.save_interval == 0:
            torch.save(Student.state_dict(), f"Student_epoch{epoch+1}.pth")

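
# Config keys read by this script (a sketch of what configs/training/stage1-base.yaml
# is expected to provide; the values below are placeholders, not the real config):
#
#   training:
#     lr: 2.0e-4
#     base_epochs: 100
#     hr_epochs: 50
#     student_epochs: 50
#     log_interval: 1
#     save_interval: 10
#     lambda_supervised: 1.0
#     lambda_unsupervised: 1.0
#     lambda_perceptual: 1.0
#     lambda_gaze: 1.0
#     n_sample_frames: 2
#     sample_rate: 4
#     video_dir: ./junk
#     json_file: ./data/overfit.json
#   data:
#     train_width: 256
#     train_height: 256
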
def main(cfg: OmegaConf) -> None:
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # As augmentation for both source and target images we use color jitter and a random
    # flip; these expect tensors in [0, 1], so normalization is applied last.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.Normalize([0.5], [0.5])
    ])

    dataset = EMODataset(
        use_gpu=use_cuda,
        width=cfg.data.train_width,
        height=cfg.data.train_height,
        n_sample_frames=cfg.training.n_sample_frames,
        sample_rate=cfg.training.sample_rate,
        img_scale=(1.0, 1.0),
        video_dir=cfg.training.video_dir,
        json_file=cfg.training.json_file,
        transform=transform
    )

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    Gbase = model.Gbase().to(device)
    Dbase = model.Discriminator(input_nc=3).to(device)  # 🤷

    train_base(cfg, Gbase, Dbase, dataloader)

    GHR = model.GHR()
    GHR.Gbase.load_state_dict(Gbase.state_dict())
    Dhr = model.Discriminator(input_nc=3).to(device)  # 🤷 currently unused: train_hr has no adversarial term
    # train_hr optimizes the enhancer passed as its third argument; Genh is assumed to be
    # exposed by the GHR wrapper (adjust to match model.GHR).
    train_hr(cfg, GHR, GHR.Genh, dataloader)

    Student = model.Student(num_avatars=100)  # this should equal the number of celebs in the dataset
    train_student(cfg, Student, GHR, dataloader)

    torch.save(Gbase.state_dict(), 'Gbase.pth')
    torch.save(GHR.state_dict(), 'GHR.pth')
    torch.save(Student.state_dict(), 'Student.pth')


if __name__ == "__main__":
    config = OmegaConf.load("./configs/training/stage1-base.yaml")
    main(config)
--------------------------------------------------------------------------------
/warp.py:
--------------------------------------------------------------------------------
import math
import torch


'''
This function converts the head pose predictions to degrees.
It takes the predicted head pose tensor (pred) as input, assumed to hold one value per pose bin (Hopenet-style 66 bins of 3 degrees each).
It creates an index tensor (idx_tensor) with one entry per bin.
It performs a weighted sum of the head pose predictions multiplied by the index tensor.
The result is then scaled and shifted to obtain the head pose in degrees.
'''
def headpose_pred_to_degree(pred):
    # pred: (bs, num_bins) bin weights; softmax is assumed to have been applied upstream
    device = pred.device
    num_bins = pred.shape[-1]
    idx_tensor = torch.arange(num_bins, dtype=torch.float32, device=device)
    degree = torch.sum(pred * idx_tensor, dim=-1) * 3 - 99
    return degree


'''
This function computes the rotation matrix based on the yaw, pitch, and roll angles.
It takes the yaw, pitch, and roll angles (in degrees) as 1-D tensors of shape (batch_size,).
It converts the angles from degrees to radians using torch.deg2rad.
It creates separate rotation matrices for roll, pitch, and yaw using the corresponding angles.
It combines the rotation matrices using Einstein summation (torch.einsum) to obtain the final rotation matrix.
'''
def get_rotation_matrix(yaw, pitch, roll):
    # Angles stay 1-D (bs,), so each cos/sin below matches the (bs,) slice it is assigned to.
    yaw = torch.deg2rad(yaw)
    pitch = torch.deg2rad(pitch)
    roll = torch.deg2rad(roll)

    roll_mat = torch.zeros(roll.shape[0], 3, 3).to(roll.device)
    roll_mat[:, 0, 0] = torch.cos(roll)
    roll_mat[:, 0, 1] = -torch.sin(roll)
    roll_mat[:, 1, 0] = torch.sin(roll)
    roll_mat[:, 1, 1] = torch.cos(roll)
    roll_mat[:, 2, 2] = 1

    pitch_mat = torch.zeros(pitch.shape[0], 3, 3).to(pitch.device)
    pitch_mat[:, 0, 0] = torch.cos(pitch)
    pitch_mat[:, 0, 2] = torch.sin(pitch)
    pitch_mat[:, 1, 1] = 1
    pitch_mat[:, 2, 0] = -torch.sin(pitch)
    pitch_mat[:, 2, 2] = torch.cos(pitch)

    yaw_mat = torch.zeros(yaw.shape[0], 3, 3).to(yaw.device)
    yaw_mat[:, 0, 0] = torch.cos(yaw)
    yaw_mat[:, 0, 2] = -torch.sin(yaw)
    yaw_mat[:, 1, 1] = 1
    yaw_mat[:, 2, 0] = torch.sin(yaw)
    yaw_mat[:, 2, 2] = torch.cos(yaw)

    rot_mat = torch.einsum('bij,bjk,bkm->bim', yaw_mat, pitch_mat, roll_mat)
    return rot_mat
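
# Worked example for the two helpers above (illustrative only). With the assumed
# 66-bin / 3-degree convention, a weight vector concentrated on bin 0 maps to
# 3 * 0 - 99 = -99 degrees, bin 33 to 0 degrees, and bin 65 to +96 degrees, e.g.:
#
#   probs = torch.zeros(1, 66); probs[0, 33] = 1.0
#   headpose_pred_to_degree(probs)                                   # -> tensor([0.])
#   get_rotation_matrix(torch.zeros(1), torch.zeros(1), torch.zeros(1))
#   # -> a (1, 3, 3) batch holding the identity rotation for zero angles
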
'''
This function creates a coordinate grid based on the given spatial size.
It takes the spatial size (spatial_size) and data type (type) as input.
It creates 1D tensors (x, y, z) representing the coordinates along each dimension.
It normalizes the coordinate values to the range [-1, 1].
It meshes the coordinate tensors using broadcasting to create a 3D coordinate grid.
The resulting coordinate grid has shape (depth, height, width, 3), where the last dimension holds the (x, y, z) coordinates.
'''
def make_coordinate_grid(spatial_size, type):
    d, h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)
    z = torch.arange(d).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)
    z = (2 * (z / (d - 1)) - 1)

    # Broadcast each coordinate over a (d, h, w) volume; the (x, y, z) channel order of
    # the last dimension matches what F.grid_sample expects for 5D inputs.
    xx = x.view(1, 1, -1).repeat(d, h, 1)
    yy = y.view(1, -1, 1).repeat(d, 1, w)
    zz = z.view(-1, 1, 1).repeat(1, h, w)

    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
    return meshed

def compute_rt_warp2(rt, v_s, inverse=False):
    bs, _, d, h, w = v_s.shape
    yaw, pitch, roll = rt['yaw'], rt['pitch'], rt['roll']
    yaw = headpose_pred_to_degree(yaw)
    pitch = headpose_pred_to_degree(pitch)
    roll = headpose_pred_to_degree(roll)

    rot_mat = get_rotation_matrix(yaw, pitch, roll)  # (bs, 3, 3)

    # Invert the transformation matrix if needed
    if inverse:
        rot_mat = torch.inverse(rot_mat)

    rot_mat = rot_mat.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
    rot_mat = rot_mat.repeat(1, d, h, w, 1, 1)

    identity_grid = make_coordinate_grid((d, h, w), type=v_s.type())
    identity_grid = identity_grid.view(1, d, h, w, 3)
    identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

    # Translation: assumed to be supplied as rt['t'] with shape (bs, 3); falls back to
    # zero translation when the key is absent.
    t = rt.get('t', torch.zeros(bs, 3, dtype=v_s.dtype, device=v_s.device))
    t = t.view(t.shape[0], 1, 1, 1, 3)

    # Rotate the identity grid, then translate
    warp_field = torch.bmm(identity_grid.reshape(-1, 1, 3), rot_mat.reshape(-1, 3, 3))
    warp_field = warp_field.reshape(identity_grid.shape)
    warp_field = warp_field - t

    return warp_field
--------------------------------------------------------------------------------
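
A minimal end-to-end sketch of how the warp field above can be applied to a feature volume (assumptions: rt carries 66-bin pose probabilities plus an optional (bs, 3) translation under 't', and the field is consumed by torch.nn.functional.grid_sample):

    import torch
    import torch.nn.functional as F
    from warp import compute_rt_warp2

    bs, c, d, h, w = 2, 32, 16, 64, 64
    v_s = torch.randn(bs, c, d, h, w)                  # appearance feature volume
    rt = {
        'yaw':   torch.softmax(torch.randn(bs, 66), dim=-1),
        'pitch': torch.softmax(torch.randn(bs, 66), dim=-1),
        'roll':  torch.softmax(torch.randn(bs, 66), dim=-1),
        't':     0.1 * torch.randn(bs, 3),
    }
    warp_field = compute_rt_warp2(rt, v_s)             # (bs, d, h, w, 3)
    warped = F.grid_sample(v_s, warp_field, align_corners=True)  # (bs, c, d, h, w)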