├── .gitignore ├── .vscode └── launch.json ├── EMOAnimationPipeline.py ├── ExtractFrames.py ├── Net.py ├── README.md ├── _assets └── speech.wav ├── architecture.png ├── camera.py ├── configs ├── inference.yaml ├── training │ ├── stage0.yaml │ ├── stage1.yaml │ ├── stage2.yaml │ └── stage3.yaml └── unet-config.yaml ├── data ├── celebvhq_info.json └── overfit.json ├── depthwise.py ├── inference.py ├── junk ├── Animate Anyone: Consistent and Controlla ├── AttentionNotes.md ├── AudioAttention │ ├── dataset.py │ ├── inference.py │ ├── model.py │ ├── synthesize.py │ ├── synthetic_dataset │ │ ├── beep_0.png │ │ ├── beep_0.wav.wav │ │ ├── beep_1.png │ │ ├── beep_1.wav.wav │ │ ├── beep_2.png │ │ ├── beep_2.wav.wav │ │ ├── beep_3.png │ │ ├── beep_3.wav.wav │ │ ├── beep_4.png │ │ ├── beep_4.wav.wav │ │ ├── beep_5.png │ │ ├── beep_5.wav.wav │ │ ├── beep_6.png │ │ ├── beep_6.wav.wav │ │ ├── beep_7.png │ │ ├── beep_7.wav.wav │ │ ├── beep_8.png │ │ ├── beep_8.wav.wav │ │ ├── beep_9.png │ │ ├── beep_9.wav.wav │ │ ├── buzz_0.png │ │ ├── buzz_0.wav.wav │ │ ├── buzz_1.png │ │ ├── buzz_1.wav.wav │ │ ├── buzz_2.png │ │ ├── buzz_2.wav.wav │ │ ├── buzz_3.png │ │ ├── buzz_3.wav.wav │ │ ├── buzz_4.png │ │ ├── buzz_4.wav.wav │ │ ├── buzz_5.png │ │ ├── buzz_5.wav.wav │ │ ├── buzz_6.png │ │ ├── buzz_6.wav.wav │ │ ├── buzz_7.png │ │ ├── buzz_7.wav.wav │ │ ├── buzz_8.png │ │ ├── buzz_8.wav.wav │ │ ├── buzz_9.png │ │ ├── buzz_9.wav.wav │ │ ├── tick_0.png │ │ ├── tick_0.wav.wav │ │ ├── tick_1.png │ │ ├── tick_1.wav.wav │ │ ├── tick_2.png │ │ ├── tick_2.wav.wav │ │ ├── tick_3.png │ │ ├── tick_3.wav.wav │ │ ├── tick_4.png │ │ ├── tick_4.wav.wav │ │ ├── tick_5.png │ │ ├── tick_5.wav.wav │ │ ├── tick_6.png │ │ ├── tick_6.wav.wav │ │ ├── tick_7.png │ │ ├── tick_7.wav.wav │ │ ├── tick_8.png │ │ ├── tick_8.wav.wav │ │ ├── tick_9.png │ │ └── tick_9.wav.wav │ └── train.py ├── BroadcastinExample.py ├── DiffusedHeads.txt ├── EMO: Emote Portrait Alive - Generating.txt ├── EMo-write-up.txt ├── M2Ohb0FAaJU_1.mp4 ├── M2Ohb0FAaJU_1.wav ├── bla.png ├── frame_0094_debug.jpg └── pipeline.png ├── magicanimate ├── models │ ├── all.py │ ├── appearance_encoder.py │ ├── attention.py │ ├── controlnet.py │ ├── embeddings.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── orig_attention.py │ ├── resnet.py │ ├── stable_diffusion_controlnet_reference.py │ ├── unet.py │ ├── unet_3d_blocks.py │ └── unet_controlnet.py ├── pipelines │ ├── animation.py │ ├── context.py │ └── pipeline_animation.py └── utils │ ├── dist_tools.py │ ├── util.py │ └── videoreader.py ├── models ├── motionmodule.py └── videonet.py ├── requirements.txt ├── train_stage_1_referencenet.py ├── train_stage_2_temporal_audio.py ├── train_stage_3_speedlayers.py ├── video.py └── videonet_animatediff.py /.gitignore: -------------------------------------------------------------------------------- 1 | images_folder 2 | *.pyc 3 | temp/debug_face_mask_xgMGnhrKZQI_14-99.png 4 | junk/media/Tex/* 5 | junk/AudioAttention/synthetic_dataset/*.* 6 | junk/media/videos/depthwise/*.* 7 | junk/media/videos/bla/* 8 | junk/media/videos/* 9 | junk/media/texts/* 10 | runs/* -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | // "program": "./junk/AudioAttention/train.py", 12 | "program": "./train_stage_1_0.py", 13 | // "program": "./junk/AudioAttention/inference.py", 14 | "console": "integratedTerminal" 15 | } 16 | ] 17 | } -------------------------------------------------------------------------------- /ExtractFrames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import json 4 | import os 5 | 6 | # Your JSON data for the video clip 7 | clip_data = { 8 | "M2Ohb0FAaJU_1": { 9 | "ytb_id": "M2Ohb0FAaJU", 10 | "duration": {"start_sec": 81.62, "end_sec": 86.17}, 11 | "bbox": {"top": 0.0, "bottom": 0.8815, "left": 0.1964, "right": 0.6922}, 12 | "attributes": { 13 | "appearance": [0, 0, 1], # Truncated for example purposes 14 | "action": [0, 0, 0], # Truncated for example purposes 15 | "emotion": {"sep_flag": False, "labels": "neutral"} 16 | }, 17 | "version": "v0.1" 18 | } 19 | } 20 | 21 | # Define the function to extract frames 22 | def extract_frames(video_path, clip_info): 23 | # Open the video file 24 | cap = cv2.VideoCapture(video_path) 25 | fps = cap.get(cv2.CAP_PROP_FPS) # Frames per second 26 | start_frame = int(clip_info['duration']['start_sec'] * fps) 27 | end_frame = int(clip_info['duration']['end_sec'] * fps) 28 | bbox = clip_info['bbox'] 29 | 30 | # Set video to start frame 31 | cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) 32 | 33 | # Extract frames 34 | for frame_num in range(start_frame, end_frame + 1): 35 | ret, frame = cap.read() 36 | if not ret: 37 | break # Break the loop if frames cannot be read 38 | 39 | # Convert to PIL image for easier cropping 40 | frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 41 | 42 | # Calculate bounding box coordinates 43 | width, height = frame.size 44 | left = bbox['left'] * width 45 | top = bbox['top'] * height 46 | right = bbox['right'] * width 47 | bottom = bbox['bottom'] * height 48 | frame = frame.crop((left, top, right, bottom)) 49 | 50 | # Save the frame with bounding box applied 51 | frame.save(f"frame_{frame_num}.png") 52 | 53 | cap.release() 54 | 55 | 56 | def extract_and_save_frames(video_path, images_folder): 57 | # Create a subfolder with the same name as the video file (without extension) 58 | video_name = os.path.splitext(os.path.basename(video_path))[0] 59 | subfolder_path = os.path.join(images_folder, video_name) 60 | os.makedirs(subfolder_path, exist_ok=True) 61 | 62 | # Open the video file 63 | cap = cv2.VideoCapture(video_path) 64 | 65 | if not cap.isOpened(): 66 | print("Error: Could not open video.") 67 | return 68 | 69 | frame_count = 0 70 | 71 | while True: 72 | ret, frame = cap.read() 73 | if not ret: 74 | break # Exit the loop if no more frames are available 75 | 76 | frame_filename = os.path.join(subfolder_path, f"frame_{frame_count:04d}.jpg") 77 | cv2.imwrite(frame_filename, frame) 78 | frame_count += 1 79 | 80 | cap.release() 81 | print(f"Total frames extracted: {frame_count}") 82 | 83 | # Assuming your video is named 'M2Ohb0FAaJU_1.mp4' and located in the current directory 84 | video_path = 'M2Ohb0FAaJU_1.mp4' 85 | extract_and_save_frames(video_path,'./images_folder') 86 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | # EMO: Emote Portrait Alive - 4 | Using ChatGPT to reverse engineer code from the HumanAIGC/EMO white paper. Work in progress - WIP. 5 | 6 | 7 | 8 | # UPDATE - AniPortrait achieves near EMO-like results. 9 | https://github.com/Zejun-Yang/AniPortrait 10 | Here's a sample I made using its vid2vid mode - it works fine. 11 | https://drive.google.com/file/d/1HaHPZbllOVPhbGkvV3aHLtcEew9CZGUV/view?usp=sharing 12 | ![Image](https://private-user-images.githubusercontent.com/289994/317962559-265b28ae-f16f-47ba-883d-4d3ae82d176d.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTE2OTA4NDAsIm5iZiI6MTcxMTY5MDU0MCwicGF0aCI6Ii8yODk5OTQvMzE3OTYyNTU5LTI2NWIyOGFlLWYxNmYtNDdiYS04ODNkLTRkM2FlODJkMTc2ZC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjQwMzI5JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI0MDMyOVQwNTM1NDBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT01ZGIyYzFiMGM3MWUyNjQ0NjA1ZGQ4OThiZDBlYzFmMDM5NDVjOGExODQ2NGFjNWE2YTJhNzcxMzA5YzM4MDE2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.xQoPWWmnwcxi3SQ5Cy9Y005SPY-1zMNyLMkRXh5J0g8) 13 | 14 | ![Image](https://github.com/johndpope/Emote-hack/assets/289994/0d758a3a-841f-4849-b58c-439dda05c9a7) 15 | 16 | 17 | https://arxiv.org/html/2402.17485v1 18 | 19 | 20 | ## WARNING - this repo is a work in progress. If you're here to train the model - come back later. The classes here are more like placeholders / building blocks. 21 | The heavy lifting now is implementing the UNet denoising and integrating the attention modules. 22 | 23 | 24 | 25 | ## Background papers to research / study 26 | - **AnimateDiff** (no training code?) 27 | - **MagicAnimate** (no training code?) 28 | - **AnimateAnyone** (no code) 29 | - **Moore-AnimateAnyone** (training code) 30 | There's training code for 3 stages: 31 | https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/train_stage_1.py 32 | - **AnimateAnyone** - https://github.com/jimmyl02/animate/tree/main/animate-anyone 33 | 3 training stages here: 34 | https://github.com/jimmyl02/animate/tree/main/animate-anyone 35 | - **DiffusedHeads** - (no training code) https://github.com/MStypulkowski/diffused-heads 36 | 37 | While this uses a pose guider - it's not hard to see a DWPose / facial-landmark signal driving the animation. https://www.reddit.com/r/StableDiffusion/comments/1281iva/new_controlnet_face_model/?rdt=50313&onetap_auto=true 38 | 39 | 40 | These papers build on previous code. 41 | 42 | 43 | 44 | 45 | Claude 3 has been the best tool for understanding the paper. 46 | It's possible to upload the text of the paper, the diagram, and all the code at once - it has a 200k-token context window. 47 | 48 | # Model Architecture: 49 | Almost all the models are here: 50 | https://github.com/johndpope/Emote-hack/blob/main/Net.py 51 | 52 | I'm exploring audio attention in the junk folder. 53 | There's a synthesize script that will generate paired sounds / images. 54 | The idea is to train the audio attention so that when it gets 55 | a specific sound - it corresponds to specific facial movements. 56 | This needs further exploration / testing. 57 | ./junk/AudioAttention/synthesize.py 58 | Ideally the network would take a sound (wav2vec features) and produce a facial expression. The FaceLocator is drafted - it could use extra eyes - the paper says the face region is an M mask across all the video frames.
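To make the audio-conditioning idea concrete, here is a minimal sketch (not the repo's final API) of pulling wav2vec 2.0 features for the audio window around a single video frame - roughly what `Wav2VecFeatureExtractor.extract_features_for_frame` is aiming at. The 16 kHz target rate, the +/- m frame context window, and the mean pooling are illustrative assumptions.

```python
# Minimal sketch: wav2vec 2.0 features for the audio around one video frame.
# Assumptions: 16 kHz mono target rate, +/- m video frames of audio context, mean pooling.
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

def audio_features_for_frame(audio_path, frame_index, fps=30, m=2):
    waveform, sr = torchaudio.load(audio_path)               # (channels, time)
    waveform = waveform.mean(dim=0)                          # mix down to mono
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000
    spf = sr // fps                                          # audio samples per video frame
    start = max(0, (frame_index - m) * spf)
    window = waveform[start:(frame_index + m + 1) * spf]     # audio slice around the frame
    inputs = processor(window.numpy(), sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        hidden = wav2vec(inputs.input_values).last_hidden_state   # (1, T, 768)
    return hidden.mean(dim=1)                                # (1, 768) pooled per-frame feature
```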
59 | 60 | 61 | ## Face Locator: 62 | The face locator is a separate module that learns to detect and localize the face region in a single input image.It takes a reference image as input and outputs the corresponding face region mask.(DRAFTED - train_stage_0.py) 63 | UPDATE - I think we can substitute this work for Alibaba's existing trained model (6.8gb) to drop in replace and provide mask conditioning https://github.com/johndpope/Emote-hack/issues/28 64 | 65 | 66 | ## Speed Encoder: 67 | The speed encoder takes the audio waveform as input and extracts speed embeddings. 68 | The speed embeddings encode the velocity and motion information derived from the audio. 69 | 70 | ## Backbone Network (Audio-Driven Generator): 71 | The backbone network is an audio-driven generator that takes the following inputs: 72 | The face region image extracted by the face locator from the reference image. 73 | The speed embeddings obtained from the speed encoder. 74 | Noisy latents generated from the face region image. 75 | The backbone network generates the output video frames conditioned on the audio and the reference image. 76 | It incorporates the speed embeddings to guide the motion and velocity of the generated frames. 77 | 78 | 79 | # Inference Process: 80 | 81 | ## Reference Image: 82 | During inference, the user provides a single reference image of the desired character. 83 | ## Face Locator: 84 | The face locator is applied to the reference image to detect and extract the face region. 85 | The face region mask is obtained from the face locator. 86 | ## Audio Waveform: 87 | The user provides an audio waveform as input, which can be a speech or any other audio signal. 88 | ## Speed Encoder: 89 | The audio waveform is passed through the speed encoder to obtain the speed embeddings. 90 | The speed embeddings encode the velocity and motion information derived from the audio. 91 | 92 | ## Backbone Network (Audio-Driven Generator): 93 | The extracted face region image, speed embeddings, and noisy latents are fed into the backbone network. 94 | The backbone network generates the output video frames conditioned on the audio and the reference image. 95 | The speed embeddings guide the motion and velocity of the generated frames, ensuring synchronization with the audio. 96 | 97 | # Training Process: 98 | 99 | ## Face Locator: 100 | The face locator is trained separately using a dataset of images with corresponding face region annotations or masks. 101 | 102 | ## Speed Encoder: 103 | The speed encoder is trained using a dataset of audio waveforms and corresponding velocity or motion annotations. 104 | ## Backbone Network (Audio-Driven Generator): 105 | The backbone network is trained using a dataset consisting of reference images, audio waveforms, and corresponding ground truth video frames. 106 | During training, the face locator extracts the face regions from the reference images, and the speed encoder provides the speed embeddings from the audio waveforms. 107 | The backbone network learns to generate video frames that match the ground truth frames while being conditioned on the audio and reference image. 108 | 109 | In this rearchitected model, the inference process takes a single reference image and an audio waveform as input, and the model generates the output video frames conditioned on the audio and the reference image. The face locator and speed encoder are used to extract the necessary information from the inputs, which is then fed into the backbone network for generating the output video. 
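As one concrete reading of the speed conditioning described above, here is a minimal sketch of a bucketed speed encoder: a head-rotation speed is softly assigned to a fixed set of speed buckets, and each bucket owns a learnable embedding. The bucket range, count, and soft-assignment rule are illustrative assumptions, not the trained module.

```python
# Minimal sketch of bucketed speed conditioning (illustrative assumptions, not the trained module).
import torch
import torch.nn as nn

class SpeedEncoderSketch(nn.Module):
    def __init__(self, num_speed_buckets=9, speed_embedding_dim=64, max_speed=1.0):
        super().__init__()
        # Evenly spaced bucket centers over an assumed speed range [-max_speed, max_speed].
        self.register_buffer("centers", torch.linspace(-max_speed, max_speed, num_speed_buckets))
        self.radius = (2 * max_speed) / (num_speed_buckets - 1)
        self.embeddings = nn.Embedding(num_speed_buckets, speed_embedding_dim)

    def forward(self, head_rotation_speeds):                     # (batch,) scalar speeds
        # Soft-assign each speed to buckets by distance, then mix the bucket embeddings.
        dist = (head_rotation_speeds[:, None] - self.centers[None, :]).abs()
        weights = torch.softmax(-dist / self.radius, dim=-1)     # (batch, num_buckets)
        return weights @ self.embeddings.weight                  # (batch, speed_embedding_dim)

# Usage: per-clip head-rotation speeds in, fixed-size speed embeddings out.
speed_embeddings = SpeedEncoderSketch()(torch.tensor([0.1, -0.4]))
```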
110 | 111 | 112 | 113 | ## Training Data (☢️ don't need this yet.) 114 | 115 | - **Total Videos:** 36,000 facial videos 116 | - **Total Size:** 40GB 117 | 118 | 119 | ### Training Strategy 120 | For now - to simplify the problem - we can use a single video, ./data/M2Ohb0FAaJU_1.mp4. We don't need the 40GB of videos. 121 | Once all stages are trained on this single video (by overfitting this one use case), we should be able to give EMO the first frame + audio and it should produce a video with the head moving. 122 | 123 | 124 | 125 | ### Torrent Download 126 | 127 | You can download the dataset via the provided magnet link or by visiting [Academic Torrents](https://academictorrents.com/details/843b5adb0358124d388c4e9836654c246b988ff4). 128 | 129 | ```plaintext 130 | magnet:?xt=urn:btih:843b5adb0358124d388c4e9836654c246b988ff4&dn=CelebV-HQ&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=https%3A%2F%2Fipv6.academictorrents.com%2Fannounce.php 131 | ``` 132 | 133 | 134 | 135 | ### Sample Video (Cropped & Trimmed) 136 | 137 | Note: The sample includes rich tagging. For more details, see `./data/overfit.json`. 138 | 139 | [![Watch the Sample Video](./junk/frame_0094_debug.jpg)](./junk/M2Ohb0FAaJU_1.mp4) 140 | 141 | 142 | 143 | ### Models / architecture 144 | (in flux) 145 | 146 | 147 | 148 | ```javascript 149 | - ✅ ReferenceNet 150 | - __init__(self, config, reference_unet, denoising_unet, vae, dtype) 151 | - forward(self, reference_image, motion_features, timesteps) 152 | 153 | - ✅ SpeedEncoder 154 | - __init__(num_speed_buckets, speed_embedding_dim) 155 | - get_bucket_centers() 156 | - get_bucket_radii() 157 | - encode_speed(head_rotation_speed) 158 | - forward(head_rotation_speeds) 159 | 160 | - CrossAttentionLayer 161 | - __init__(feature_dim) 162 | - forward(latent_code, audio_features) 163 | 164 | - AudioAttentionLayers 165 | - __init__(feature_dim, num_layers) 166 | - forward(latent_code, audio_features) 167 | 168 | - ✅ EMOModel 169 | - __init__(vae, image_encoder, config) 170 | - forward(noisy_latents, timesteps, ref_image, motion_frames, audio_features, head_rotation_speeds) 171 | 172 | - ✅ Wav2VecFeatureExtractor 173 | - __init__(model_name, device) 174 | - extract_features_from_mp4(video_path, m, n) 175 | - extract_features_for_frame(video_path, frame_index, m) 176 | 177 | - AudioFeatureModel 178 | - __init__(input_size, output_size) 179 | - forward(x) 180 | 181 | - ✅ FaceLocator 182 | - __init__() 183 | - forward(images) 184 | 185 | - ✅ FaceHelper 186 | - __init__() 187 | - __del__() 188 | - generate_face_region_mask(frame_image, video_id, frame_idx) 189 | - generate_face_region_mask_np_image(frame_np, video_id, frame_idx, padding) 190 | - generate_face_region_mask_pil_image(frame_image, video_id, frame_idx) 191 | - calculate_pose(face2d) 192 | - draw_axis(img, yaw, pitch, roll, tdx, tdy, size) 193 | - get_head_pose(image_path) 194 | - get_head_pose_velocities_at_frame(video_reader, frame_index, n_previous_frames) 195 | 196 | - EmoVideoReader 197 | - __init__(pixel_transform, cond_transform, state) 198 | - augmentedImageAtFrame(index) 199 | - augmentation(images, transform, state) 200 | 201 | - ✅ EMODataset 202 | - __init__(use_gpu, data_dir, sample_rate, n_sample_frames, width, height, img_scale, img_ratio, video_dir, drop_ratio, json_file, stage, transform) 203 | - __len__() 204 | - augmentation(images, transform, state) 205 | - __getitem__(index) 206 | 207 | ``` 208 | 209 |
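To illustrate the FaceLocator contract from the listing above (images in, single-channel face-region mask out), here is a minimal convolutional sketch. The layer widths are arbitrary and this is not the drafted module from train_stage_0.py.

```python
# Minimal sketch of the FaceLocator contract: images -> single-channel face-region mask M.
# Layer widths are arbitrary; this is not the drafted train_stage_0.py module.
import torch
import torch.nn as nn

class FaceLocatorSketch(nn.Module):
    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, 1, 1),                      # 1-channel mask logits
        )

    def forward(self, images):                            # (batch, 3, H, W)
        return torch.sigmoid(self.net(images))            # (batch, 1, H, W) mask in [0, 1]

mask = FaceLocatorSketch()(torch.randn(1, 3, 512, 512))   # mask keeps the input resolution
```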
214 | 215 | magicanimate code - it has custom blocks for unet - maybe very useful when wiring up the attentions in unet. 216 | ```javascript 217 | - EMOAnimationPipeline (copied from magicanimate) 218 | - has some training code / this should not need text encoder / clip to aling with EMO paper. 219 | ``` 220 | 221 | -------------------------------------------------------------------------------- /_assets/speech.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/_assets/speech.wav -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/architecture.png -------------------------------------------------------------------------------- /camera.py: -------------------------------------------------------------------------------- 1 | # **************************************************************************** # 2 | # # 3 | # ::: :::::::: # 4 | # camera.py :+: :+: :+: # 5 | # +:+ +:+ +:+ # 6 | # By: taston +#+ +:+ +#+ # 7 | # +#+#+#+#+#+ +#+ # 8 | # Created: 2023/04/25 15:41:23 by taston #+# #+# # 9 | # Updated: 2023/09/01 13:41:21 by taston ### ########.fr # 10 | # # 11 | # **************************************************************************** # 12 | 13 | import numpy as np 14 | import cv2 15 | from datetime import datetime 16 | from video import Video 17 | 18 | class Camera: 19 | """ 20 | A class used to represent a Camera. 21 | 22 | ... 23 | 24 | Attributes 25 | ---------- 26 | focal_length : float 27 | float representing the focal length of the Camera in mm 28 | internal_matrix : ndarray 29 | array representing the Camera's intrinsic parameters 30 | distortion_matrix : ndarray 31 | array representing the Camera's lens distortion parameters 32 | calibrator : Calibrator 33 | Calibrator object used for camera calibration 34 | calibrated : bool 35 | bool for quick checking if camera has been calibrated 36 | video : Video 37 | Video object where the footage has been shot using this Camera 38 | 39 | Methods 40 | ------- 41 | calibrate(checkerboard=(9,6), video=Video()): 42 | Performs calibration on the camera 43 | """ 44 | def __init__(self): 45 | """ 46 | Parameters 47 | ---------- 48 | ... 49 | """ 50 | 51 | width = 1280 52 | height = 720 53 | self.focal_length = height * 1.28 54 | # self.focal_length = 5000 55 | self.internal_matrix = np.array([[self.focal_length, 0, width/2], 56 | [0, self.focal_length, height/2], 57 | [0, 0, 1]]) 58 | self.distortion_matrix = np.zeros((4, 1), dtype=np.float64) 59 | self.calibrated = False 60 | 61 | 62 | def calibrate(self, checkerboard=(9,6), video=Video(), show=True): 63 | """Creates a calibrator object and calibrates the Camera. 64 | 65 | If arguments checkerboard and video aren't passed in, the 66 | default checkerboard pattern and an empty video are used. 
67 | 68 | Parameters 69 | ---------- 70 | checkerboard : tuple, optional 71 | Checkerboard pattern used in camera calibration (default is 9x6) 72 | video : Video, optional 73 | Video used to calibrate camera 74 | """ 75 | 76 | self.video = video 77 | self.calibrator = Calibrator(checkerboard, self.video, show) 78 | self.calibrated = True 79 | 80 | self.internal_matrix, self.distortion_matrix = self.calibrator.matrix, self.calibrator.distortion 81 | 82 | return self 83 | 84 | 85 | class Checkerboard: 86 | """ 87 | A class used to represent a calibration Checkerboard 88 | 89 | ... 90 | 91 | Attributes 92 | ---------- 93 | dimensions : tuple 94 | tuple of checkerboard pattern dimensions 95 | min_points : int 96 | integer threshold of minimum detected points for 97 | checkerboard to be considered found 98 | objectp3d : ndarray 99 | array of checkerboard points in three dimensions 100 | threedpoints : list 101 | list of checkerboard points in three dimensions 102 | for each frame where a checkerboard is found 103 | twodpoints : list 104 | list of detected checkerboard points in two 105 | dimensions for each frame 106 | 107 | Methods 108 | ------- 109 | get_corners(gray_frame) 110 | Finds checkerboard corners in a given grayscale frame 111 | """ 112 | def __init__(self, dimensions = (9,6)): 113 | """ 114 | Parameters 115 | ---------- 116 | dimensions : tuple, optional 117 | Checkerboard pattern used in camera calibration (default is 9x6) 118 | """ 119 | 120 | print('Checkerboard created') 121 | self.dimensions = dimensions 122 | self.min_points = 50 123 | self.twodpoints = [] 124 | self.threedpoints = [] 125 | self.objectp3d = np.zeros((1, self.dimensions[0] 126 | * self.dimensions[1], 127 | 3), np.float32) 128 | self.objectp3d[0, :, :2] = np.mgrid[0:self.dimensions[0], 129 | 0:self.dimensions[1]].T.reshape(-1, 2) 130 | 131 | def get_corners(self, gray_frame): 132 | """ 133 | Looks for checkerboard corners in a given grayscale video 134 | frame. 135 | 136 | Parameters 137 | ---------- 138 | gray_frame : ndarray 139 | ndarray representing grayscale frame from video 140 | 141 | Returns 142 | ------- 143 | ret : bool 144 | bool representing if corner search was successful 145 | corners : ndarray 146 | ndarray containing coordinates of corners 147 | """ 148 | 149 | ret, corners = cv2.findChessboardCorners( 150 | gray_frame, self.dimensions, 151 | cv2.CALIB_CB_ADAPTIVE_THRESH 152 | + cv2.CALIB_CB_FAST_CHECK + 153 | cv2.CALIB_CB_NORMALIZE_IMAGE) 154 | 155 | return ret, corners 156 | 157 | 158 | class Calibrator: 159 | """ 160 | A class used to represent a camera Calibrator 161 | 162 | ... 
163 | 164 | Attributes 165 | ---------- 166 | checkerboard : Checkerboard 167 | 168 | criteria : tuple 169 | tuple of criteria for successful camera calibration 170 | distortion : ndarray 171 | ndarray of distortion parameters 172 | frame : ndarray 173 | ndarray representing video frame 174 | gray_frame : ndarray 175 | ndarray representing grayscale video frame 176 | matrix : ndarray 177 | ndarray representing camera intrinsic matrix 178 | r_vecs : ndarray 179 | ndarray of rotational vectors 180 | t_vecs : ndarray 181 | ndarray of translation vectors 182 | 183 | Methods 184 | ------- 185 | calibrate() 186 | Perform camera calibration process 187 | draw_corners(corners) 188 | Draw checkerboard corners on video frame 189 | save_outputs() 190 | Save camera parameters to csv files 191 | """ 192 | 193 | def __init__(self, checkerboard, video=Video(), show=True): 194 | """ 195 | Parameters 196 | ---------- 197 | checkerboard : tuple 198 | tuple representing Checkerboard pattern 199 | video : Video, optional 200 | Video used for camera calibration. If no video specified 201 | an empty video will be attempted. 202 | """ 203 | timestamp = datetime.now().strftime("%H:%M:%S") 204 | self.show = show 205 | self.video = video 206 | print('-'*120) 207 | print('{:<100} {:>19}'.format(f'Creating Calibrator object for video {self.video.filename}:', timestamp)) 208 | print('-'*120) 209 | print(self.video) 210 | self.checkerboard = Checkerboard(checkerboard) 211 | print(f'Checkerboard dimensions: {self.checkerboard.dimensions[0]} x {self.checkerboard.dimensions[1]}') 212 | self.criteria = (cv2.TERM_CRITERIA_EPS + 213 | cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001) 214 | self.calibrate() 215 | # self.save_outputs() 216 | timestamp = datetime.now().strftime("%H:%M:%S") 217 | print('-'*120) 218 | print('{:<100} {:>19}'.format(f'Calibrator object complete!', timestamp)) 219 | print('-'*120) 220 | 221 | def calibrate(self): 222 | """ 223 | Performs the camera calibration procedure outlined here: 224 | 225 | https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 226 | """ 227 | timestamp = datetime.now().strftime("%H:%M:%S") 228 | self.video.create_writer() 229 | print('Displaying video...') 230 | while True: 231 | ret, self.frame = self.video.cap.read() 232 | frame_number = int(self.video.cap.get(cv2.CAP_PROP_POS_FRAMES)) 233 | self.gray_frame = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY) 234 | ret, corners = self.checkerboard.get_corners(self.gray_frame) 235 | if ret: 236 | complete, image = self.draw_corners(corners) 237 | if complete: break 238 | 239 | if self.show == True: 240 | cv2.imshow('Calibrating...', self.frame) 241 | self.video.writer.write(self.frame) 242 | k = cv2.waitKey(1) 243 | if k == 27: 244 | self.video.cap.release() 245 | self.video.writer.release() 246 | cv2.destroyAllWindows() 247 | break 248 | h, w = image.shape[:2] 249 | 250 | # Perform camera calibration by given threedpoints and twodpoints 251 | ret, self.matrix, self.distortion, self.r_vecs, self.t_vecs = cv2.calibrateCamera(self.checkerboard.threedpoints, 252 | self.checkerboard.twodpoints, 253 | self.gray_frame.shape[::-1], None, None) 254 | print(f'Number of frames used for calibration: {frame_number}') 255 | 256 | return self 257 | 258 | def draw_corners(self, corners): 259 | ''' 260 | Draws corners of checkerboard onto frame to verify calibration is working 261 | 262 | Parameters 263 | ---------- 264 | corners : ndarray 265 | ndarray of the corners found for a given frame 266 | 267 | Returns 268 | ------- 269 | complete : bool 
270 | bool representing whether search for corners is complete 271 | frame : ndarray 272 | new video frame with corners drawn 273 | ''' 274 | complete = False 275 | 276 | self.checkerboard.threedpoints.append(self.checkerboard.objectp3d) 277 | # Refining pixel coordinates for given 2d points. 278 | corners2 = cv2.cornerSubPix( 279 | self.gray_frame, corners, self.checkerboard.dimensions, (-1, -1), self.criteria) 280 | self.checkerboard.twodpoints.append(corners2) 281 | # When we have minimum number of data points, stop: 282 | if len(self.checkerboard.twodpoints) > self.checkerboard.min_points: 283 | self.video.cap.release() 284 | self.video.writer.release() 285 | cv2.destroyAllWindows() 286 | complete=True 287 | 288 | # Draw and display the corners: 289 | frame = cv2.drawChessboardCorners(self.frame, 290 | self.checkerboard.dimensions, 291 | corners2, True) 292 | 293 | return complete, frame 294 | 295 | def save_outputs(self): 296 | """ 297 | Saves matrices to csv 298 | """ 299 | timestamp = datetime.now().strftime("%H:%M:%S") 300 | print('Saving outputs...') 301 | from numpy import savetxt 302 | savetxt('camera_matrix.csv', self.matrix, delimiter=',') 303 | savetxt('camera_distortion.csv', self.distortion, delimiter=',') 304 | 305 | return -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | unet_use_cross_frame_attention: false 3 | unet_use_temporal_attention: false 4 | use_motion_module: false 5 | motion_module_resolutions: 6 | - 1 7 | - 2 8 | - 4 9 | - 8 10 | motion_module_mid_block: false 11 | motion_module_decoder_only: false 12 | motion_module_type: Vanilla 13 | motion_module_kwargs: 14 | num_attention_heads: 8 15 | num_transformer_block: 1 16 | attention_block_types: 17 | - Temporal_Self 18 | - Temporal_Self 19 | temporal_position_encoding: true 20 | temporal_position_encoding_max_len: 24 21 | temporal_attention_dim_div: 1 22 | 23 | noise_scheduler_kwargs: 24 | beta_start: 0.00085 25 | beta_end: 0.012 26 | beta_schedule: "linear" -------------------------------------------------------------------------------- /configs/training/stage0.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 1 3 | train_width: 512 4 | train_height: 512 5 | sample_margin: 30 6 | sample_rate: 4 7 | n_sample_frames: 1 8 | training: 9 | batch_size: 2 10 | num_workers: 0 11 | learning_rate: 1.0e-5 12 | num_epochs: 2 13 | use_gpu_video_tensor: True 14 | video_data_dir: '/home/oem/Downloads/CelebV-HQ/celebvhq/35666' 15 | solver: 16 | gradient_accumulation_steps: 1 17 | mixed_precision: 'fp16' 18 | enable_xformers_memory_efficient_attention: True 19 | gradient_checkpointing: False 20 | max_train_steps: 30000 21 | max_grad_norm: 1.0 22 | # lr 23 | learning_rate: 1.0e-5 24 | scale_lr: False 25 | lr_warmup_steps: 1 26 | lr_scheduler: 'constant' 27 | 28 | # optimizer 29 | use_8bit_adam: True 30 | adam_beta1: 0.9 31 | adam_beta2: 0.999 32 | adam_weight_decay: 1.0e-2 33 | adam_epsilon: 1.0e-8 34 | 35 | val: 36 | validation_steps: 200 37 | 38 | 39 | noise_scheduler_kwargs: 40 | num_train_timesteps: 1000 41 | beta_start: 0.00085 42 | beta_end: 0.012 43 | beta_schedule: "scaled_linear" 44 | steps_offset: 1 45 | clip_sample: false 46 | 47 | base_model_path: './pretrained_weights/sd-image-variations-diffusers' 48 | vae_model_path: './pretrained_weights/sd-vae-ft-mse' 49 | image_encoder_path: 
'./pretrained_weights/sd-image-variations-diffusers/image_encoder' 50 | controlnet_openpose_path: './pretrained_weights/control_v11p_sd15_openpose/diffusion_pytorch_model.bin' 51 | 52 | weight_dtype: 'fp16' # [fp16, fp32] 53 | uncond_ratio: 0.1 54 | noise_offset: 0.05 55 | snr_gamma: 5.0 56 | enable_zero_snr: True 57 | 58 | 59 | seed: 12580 60 | resume_from_checkpoint: '' 61 | checkpointing_steps: 1000 62 | save_model_epoch_interval: 5 63 | exp_name: 'stage1' 64 | output_dir: './exp_output' -------------------------------------------------------------------------------- /configs/training/stage1.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 4 3 | num_epochs: 100 4 | learning_rate: 1e-4 5 | num_workers: 4 6 | log_every: 100 7 | save_every: 10 8 | checkpoint_dir: "checkpoints/stage1" 9 | 10 | data: 11 | data_dir: "data" 12 | video_dir: "videos" 13 | json_file: "metadata.json" 14 | train_width: 512 15 | train_height: 512 -------------------------------------------------------------------------------- /configs/training/stage2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | frame_dim: 1024 3 | audio_dim: 768 4 | num_heads: 8 5 | temporal_layers: 4 6 | pretrained_path: "checkpoints/stage1/latest.pt" 7 | reference_net_path: "checkpoints/stage1/reference_net_latest.pt" 8 | 9 | data: 10 | data_dir: "data" 11 | video_dir: "videos" 12 | json_file: "metadata.json" 13 | train_width: 512 14 | train_height: 512 15 | num_frames: 8 16 | audio_ctx_frames: 2 17 | sample_rate: 16000 18 | 19 | training: 20 | batch_size: 2 21 | num_epochs: 100 22 | learning_rate: 1e-5 23 | num_workers: 4 24 | log_every: 100 25 | save_every: 10 26 | checkpoint_dir: "checkpoints/stage2" 27 | device: "cuda" 28 | mixed_precision: true -------------------------------------------------------------------------------- /configs/training/stage3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | stage2_checkpoint: "checkpoints/stage2/best.pt" 3 | num_speed_buckets: 9 4 | embed_dim: 1024 5 | face_channels: 1 6 | pretrained_path: "checkpoints/stage2/latest.pt" 7 | 8 | data: 9 | data_dir: "data" 10 | video_dir: "videos" 11 | json_file: "metadata.json" 12 | train_width: 512 13 | train_height: 512 14 | num_frames: 8 15 | audio_ctx_frames: 2 16 | sample_rate: 16000 17 | 18 | training: 19 | batch_size: 2 20 | num_epochs: 50 21 | learning_rate: 1e-5 22 | num_workers: 4 23 | log_every: 100 24 | save_every: 5 25 | checkpoint_dir: "checkpoints/stage3" 26 | log_dir: "logs" 27 | device: "cuda" 28 | mixed_precision: true 29 | use_wandb: true 30 | face_loss_weight: 0.5 31 | 32 | evaluation: 33 | eval_batch_size: 1 34 | eval_frequency: 1 -------------------------------------------------------------------------------- /configs/unet-config.yaml: -------------------------------------------------------------------------------- 1 | denoising_unet_config: 2 | v2: 3 | act_fn: silu 4 | attention_head_dim: [5, 10, 20, 20] 5 | block_out_channels: [320, 640, 1280, 1280] 6 | center_input_sample: False 7 | cross_attention_dim: 1024 8 | down_block_types: ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] 9 | downsample_padding: 1 10 | dual_cross_attention: False 11 | flip_sin_to_cos: True 12 | freq_shift: 0 13 | in_channels: 4 14 | layers_per_block: 2 15 | mid_block_scale_factor: 1 16 | norm_eps: 1e-05 17 | norm_num_groups: 4 18 | num_class_embeds: null 19 | 
only_cross_attention: False 20 | out_channels: 4 21 | sample_size: 96 22 | up_block_types: ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] 23 | use_linear_projection: True 24 | upcast_attention: True 25 | class_embed_type: null 26 | resnet_time_scale_shift: default 27 | projection_class_embeddings_input_dim: null 28 | default: 29 | act_fn: silu 30 | attention_head_dim: 8 31 | block_out_channels: [320, 640, 1280, 1280] 32 | center_input_sample: False 33 | cross_attention_dim: 768 34 | down_block_types: ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] 35 | downsample_padding: 1 36 | flip_sin_to_cos: True 37 | freq_shift: 0 38 | in_channels: 4 39 | layers_per_block: 2 40 | mid_block_scale_factor: 1 41 | norm_eps: 1e-05 42 | norm_num_groups: 4 43 | out_channels: 4 44 | sample_size: 64 45 | up_block_types: ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] 46 | only_cross_attention: False 47 | use_linear_projection: False 48 | class_embed_type: null 49 | num_class_embeds: null 50 | upcast_attention: False 51 | resnet_time_scale_shift: default 52 | projection_class_embeddings_input_dim: null 53 | 54 | 55 | 56 | reference_unet_config: 57 | v2: 58 | act_fn: silu 59 | attention_head_dim: [5, 10, 20, 20] 60 | block_out_channels: [320, 640, 1280, 1280] 61 | center_input_sample: False 62 | cross_attention_dim: 1024 63 | down_block_types: ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] 64 | downsample_padding: 1 65 | dual_cross_attention: False 66 | flip_sin_to_cos: True 67 | freq_shift: 0 68 | in_channels: 4 69 | layers_per_block: 2 70 | mid_block_scale_factor: 1 71 | norm_eps: 1e-05 72 | norm_num_groups: 4 73 | num_class_embeds: null 74 | only_cross_attention: False 75 | out_channels: 4 76 | sample_size: 96 77 | up_block_types: ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] 78 | use_linear_projection: True 79 | upcast_attention: True 80 | class_embed_type: null 81 | resnet_time_scale_shift: default 82 | projection_class_embeddings_input_dim: null 83 | default: 84 | act_fn: silu 85 | attention_head_dim: 8 86 | block_out_channels: [320, 640, 1280, 1280] 87 | center_input_sample: False 88 | cross_attention_dim: 768 89 | down_block_types: ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] 90 | downsample_padding: 1 91 | flip_sin_to_cos: True 92 | freq_shift: 0 93 | in_channels: 4 94 | layers_per_block: 2 95 | mid_block_scale_factor: 1 96 | norm_eps: 1e-05 97 | norm_num_groups: 4 98 | out_channels: 4 99 | sample_size: 64 100 | up_block_types: ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] 101 | only_cross_attention: False 102 | use_linear_projection: False 103 | class_embed_type: null 104 | num_class_embeds: null 105 | upcast_attention: False 106 | resnet_time_scale_shift: default 107 | projection_class_embeddings_input_dim: null -------------------------------------------------------------------------------- /data/overfit.json: -------------------------------------------------------------------------------- 1 | {"meta_info": {"appearance_mapping": ["blurry", "male", "young", "chubby", "pale_skin", "rosy_cheeks", "oval_face", "receding_hairline", "bald", "bangs", "black_hair", "blonde_hair", "gray_hair", "brown_hair", "straight_hair", "wavy_hair", "long_hair", "arched_eyebrows", "bushy_eyebrows", "bags_under_eyes", "eyeglasses", "sunglasses", "narrow_eyes", 
"big_nose", "pointy_nose", "high_cheekbones", "big_lips", "double_chin", "no_beard", "5_o_clock_shadow", "goatee", "mustache", "sideburns", "heavy_makeup", "wearing_earrings", "wearing_hat", "wearing_lipstick", "wearing_necklace", "wearing_necktie", "wearing_mask"], "action_mapping": ["blow", "chew", "close_eyes", "cough", "cry", "drink", "eat", "frown", "gaze", "glare", "head_wagging", "kiss", "laugh", "listen_to_music", "look_around", "make_a_face", "nod", "play_instrument", "read", "shake_head", "shout", "sigh", "sing", "sleep", "smile", "smoke", "sneer", "sneeze", "sniff", "talk", "turn", "weep", "whisper", "wink", "yawn"]}, "clips": {"M2Ohb0FAaJU_1": {"ytb_id": "M2Ohb0FAaJU", "duration": {"start_sec": 81.62, "end_sec": 86.17}, "bbox": {"top": 0.0, "bottom": 0.8815, "left": 0.1964, "right": 0.6922}, "attributes": {"appearance": [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], "action": [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], "emotion": {"sep_flag": false, "labels": "neutral"}}, "version": "v0.1"}}} -------------------------------------------------------------------------------- /depthwise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusers import UNet2DConditionModel 4 | import yaml 5 | # Define a custom depthwise separable convolutional block 6 | class DepthwiseSeparableConv2d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 8 | super(DepthwiseSeparableConv2d, self).__init__() 9 | self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias) 10 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=bias) 11 | 12 | def forward(self, x): 13 | x = self.depthwise(x) 14 | x = self.pointwise(x) 15 | return x 16 | 17 | # Create a custom U-Net model with depthwise separable convolutions 18 | class CustomUNet(UNet2DConditionModel): 19 | def __init__(self, **kwargs): 20 | super().__init__(**kwargs) 21 | self.apply_depthwise_separable_conv() 22 | 23 | def apply_depthwise_separable_conv(self): 24 | for module in self.modules(): 25 | if isinstance(module, nn.Conv2d): 26 | depthwise_separable_conv = DepthwiseSeparableConv2d( 27 | module.in_channels, 28 | module.out_channels, 29 | module.kernel_size, 30 | module.stride, 31 | module.padding, 32 | module.bias is not None, 33 | ) 34 | module = depthwise_separable_conv 35 | 36 | 37 | def convert_conv_to_depthwise_separable(conv_layer, bias=True): 38 | """ 39 | Converts a regular convolutional layer to a depthwise separable convolutional layer. 40 | The weights and biases of the new layer are initialized with the weights and biases 41 | of the original convolutional layer. 
42 | """ 43 | in_channels = conv_layer.in_channels 44 | out_channels = conv_layer.out_channels 45 | kernel_size = conv_layer.kernel_size 46 | stride = conv_layer.stride 47 | padding = conv_layer.padding 48 | dilation = conv_layer.dilation 49 | 50 | depthwise_separable_conv = DepthwiseSeparableConv2d( 51 | in_channels, out_channels, kernel_size, stride, padding, bias 52 | ) 53 | 54 | # Initialize the depthwise convolution weights 55 | depthwise_separable_conv.depthwise.weight.data = conv_layer.weight.data.clone() 56 | 57 | # Initialize the pointwise convolution weights 58 | pointwise_weight = torch.sum(conv_layer.weight.data, dim=1, keepdim=True) 59 | depthwise_separable_conv.pointwise.weight.data = pointwise_weight.clone() 60 | 61 | # Initialize the biases 62 | if bias: 63 | depthwise_separable_conv.depthwise.bias.data = conv_layer.bias.data.clone() 64 | depthwise_separable_conv.pointwise.bias.data = conv_layer.bias.data.clone() 65 | 66 | return depthwise_separable_conv 67 | 68 | reference_unet = UNet2DConditionModel.from_pretrained( 69 | config.pretrained_base_model_path, 70 | subfolder="unet", 71 | ).to(dtype=weight_dtype, device="cuda") 72 | 73 | 74 | # Load the pre-trained model 75 | # pretrained_model = YourPretrainedModel(...) 76 | # pretrained_model.load_state_dict(torch.load("path/to/pretrained_weights.pth")) 77 | 78 | # Convert the convolutional layers to depthwise separable convolutions 79 | for module in pretrained_model.modules(): 80 | if isinstance(module, nn.Conv2d): 81 | depthwise_separable_conv = convert_conv_to_depthwise_separable(module) 82 | module = depthwise_separable_conv 83 | 84 | 85 | 86 | # Load the YAML configuration file 87 | with open('./configs/config.yaml', 'r') as file: 88 | config = yaml.safe_load(file) 89 | 90 | v2 = False # SD 2.1 91 | # Access the reference_unet_config based on args.v2 92 | if v2: 93 | unet_config = config['reference_unet_config']['v2'] 94 | denoise_unet_config = config['denoising_unet_config']['v2'] 95 | else: 96 | # SD 1.5 97 | unet_config = config['reference_unet_config']['default'] 98 | denoise_unet_config = config['denoising_unet_config']['default'] 99 | 100 | 101 | reference_unet = CustomUNet(**config["reference_unet_config"]) 102 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | import cv2 5 | from decord import AudioReader 6 | from Net import EMOModel 7 | import decord 8 | 9 | # Load the trained EMO model 10 | model_path = 'emo_model_stage3.pth' 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | # Instantiate the EMOModel 14 | emo_model = EMOModel( 15 | vae=None, 16 | image_encoder=None, 17 | config={ 18 | "feature_dim": 512, 19 | "num_layers": 4, 20 | "audio_feature_dim": 128, 21 | "audio_num_layers": 2, 22 | "num_speed_buckets": 5, 23 | "speed_embedding_dim": 64, 24 | "temporal_module": "conv" 25 | } 26 | ).to(device) 27 | 28 | # Load the trained weights 29 | emo_model.load_state_dict(torch.load(model_path, map_location=device)) 30 | emo_model.eval() 31 | 32 | # Define the necessary transforms 33 | transform = transforms.Compose([ 34 | transforms.Resize((256, 256)), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 37 | ]) 38 | 39 | # Load the reference image 40 | reference_image_path = 'path/to/reference/image.jpg' 41 | 
reference_image = Image.open(reference_image_path).convert('RGB') 42 | reference_image = transform(reference_image).unsqueeze(0).to(device) 43 | 44 | # Load the audio frames 45 | audio_path = 'path/to/audio/file.mp3' 46 | audio_reader = AudioReader(audio_path, ctx=decord.cpu(), sample_rate=16000, mono=True) 47 | audio_frames = audio_reader[:] 48 | 49 | # Specify the target head rotation speed - WHAT ??? TODO - fix this 50 | target_speed = 0.5 51 | 52 | # Generate the video frames 53 | output_video_path = 'video.mp4' 54 | video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (512, 512)) 55 | 56 | with torch.no_grad(): 57 | for i in range(len(audio_frames)): 58 | audio_frame = audio_frames[i].unsqueeze(0).to(device) 59 | 60 | # Perform inference 61 | generated_frame = emo_model(reference_image, audio_frame, target_speed) 62 | 63 | # Convert the generated frame tensor to an array and adjust color channels 64 | generated_frame = generated_frame.squeeze(0).permute(1, 2, 0).cpu().numpy() 65 | generated_frame = (generated_frame * 0.5 + 0.5) * 255 66 | generated_frame = cv2.cvtColor(generated_frame.astype('uint8'), cv2.COLOR_RGB2BGR) 67 | 68 | video_writer.write(generated_frame) 69 | 70 | video_writer.release() 71 | -------------------------------------------------------------------------------- /junk/AttentionNotes.md: -------------------------------------------------------------------------------- 1 | ## Reference-Attention: 2 | Analogy: The reference-attention is like the film director who ensures that the actors maintain consistency in their appearance and style throughout the movie. Just as the director refers to the initial character designs and guides the actors to stay true to their roles, the reference-attention mechanism uses the reference image to maintain consistency in the generated video frames. 3 | ## Audio-Attention: 4 | Analogy: The audio-attention is similar to the music composer and sound designers in a movie production. They create and synchronize the musical score and sound effects with the visual content to enhance the emotional impact and narrative of the movie. Similarly, the audio-attention mechanism aligns the generated video frames with the corresponding audio features, ensuring that the character's movements and expressions match the tempo and mood of the audio. 5 | ## Self-Attention (Temporal Modules): 6 | Analogy: The self-attention mechanism in the temporal modules is like the film editor who ensures smooth transitions and coherence between different scenes in a movie. The editor carefully selects and arranges the footage to create a seamless flow and maintain the overall narrative structure. Similarly, the self-attention mechanism in the temporal modules helps in maintaining the temporal consistency and smooth transitions between the generated video frames, considering the context of the surrounding frames. 7 | ## Cross-Attention (Audio-Attention and Reference-Attention): 8 | Analogy: The cross-attention mechanism, used in both audio-attention and reference-attention, is like the communication between different departments in a movie production. For example, the cinematographer works closely with the director to understand their vision and capture the desired shots, while the sound designers collaborate with the composer to create a cohesive audio-visual experience. 
Similarly, the cross-attention mechanism allows for the exchange of information between different modalities (e.g., audio and visual features) to ensure synchronization and consistency in the generated video. 9 | 10 | These analogies help to understand the roles and interactions of the different attention mechanisms in the EMO model: 11 | 12 | Reference-attention ensures consistency with the reference image, like a director guiding the actors. 13 | Audio-attention synchronizes the video with the audio, like the music composer and sound designers. 14 | Self-attention in temporal modules maintains smooth transitions and coherence, like a film editor. 15 | Cross-attention enables communication and alignment between different modalities, like the collaboration between different departments in a movie production. 16 | By working together, these attention mechanisms contribute to generating expressive and coherent portrait videos that align with the given audio and reference image. -------------------------------------------------------------------------------- /junk/AudioAttention/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchaudio 5 | import torchvision 6 | 7 | class AudioVisualDataset(Dataset): 8 | def __init__(self, dataset_dir): 9 | self.dataset_dir = dataset_dir 10 | self.audio_files = sorted([f for f in os.listdir(dataset_dir) if f.endswith('.wav')]) 11 | self.image_files = sorted([f for f in os.listdir(dataset_dir) if f.endswith('.png')]) 12 | 13 | assert len(self.audio_files) == len(self.image_files), "Number of audio files and image files should match" 14 | 15 | def __len__(self): 16 | return len(self.audio_files) 17 | 18 | def __getitem__(self, idx): 19 | audio_path = os.path.join(self.dataset_dir, self.audio_files[idx]) 20 | image_path = os.path.join(self.dataset_dir, self.image_files[idx]) 21 | 22 | # Load audio 23 | waveform, sample_rate = torchaudio.load(audio_path) 24 | assert waveform is not None, "Failed to load waveform" 25 | # Handling stereo audio by averaging the channels to convert to mono 26 | # This is just one approach; depending on your needs you might handle it differently 27 | if waveform.ndim == 2 and waveform.size(0) == 2: # Check if the audio is stereo 28 | waveform = waveform.mean(dim=0, keepdim=True) # Convert to mono by averaging the channels 29 | 30 | # Check if waveform has an unexpected batch dimension and remove it 31 | if waveform.ndim == 3: 32 | assert waveform.size(0) == 1, "Batch size should be 1" 33 | waveform = waveform.squeeze(0) # Remove the batch dimension 34 | 35 | # Resample audio if necessary 36 | target_sample_rate = 16000 # Set your desired sample rate 37 | if sample_rate != target_sample_rate: 38 | resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate) 39 | waveform = resampler(waveform) 40 | 41 | assert waveform.ndim == 2, "Waveform should have 2 dimensions after processing (channels, time)" 42 | 43 | 44 | # Load image 45 | image = Image.open(image_path).convert('RGB') 46 | assert image is not None, "Failed to load image" 47 | 48 | # Convert image to tensor 49 | transform = torchvision.transforms.Compose([ 50 | torchvision.transforms.Resize((256, 256)), 51 | torchvision.transforms.ToTensor(), 52 | ]) 53 | image = transform(image) 54 | assert image is not None, "Failed to transform image" 55 | 56 | return { 57 | 'audio': waveform, 58 | 'image': image 59 | } 
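A quick usage sketch for the dataset above, assuming the paired .wav/.png layout produced by synthesize.py and a working directory at the repo root (as with the other AudioAttention scripts); batch_size=1 sidesteps collation of variable-length waveforms.

```python
# Usage sketch for AudioVisualDataset (assumed paired .wav/.png layout from synthesize.py).
from torch.utils.data import DataLoader
from dataset import AudioVisualDataset

loader = DataLoader(AudioVisualDataset('./junk/AudioAttention/synthetic_dataset'),
                    batch_size=1, shuffle=True)           # batch_size=1 avoids padding variable-length audio
sample = next(iter(loader))
waveform, image = sample['audio'], sample['image']        # (1, 1, T) mono audio, (1, 3, 256, 256) image
```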
-------------------------------------------------------------------------------- /junk/AudioAttention/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models, transforms 3 | from PIL import Image 4 | from transformers import Wav2Vec2Processor, Wav2Vec2Model 5 | import torchaudio 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | from model import Wav2VecFeatureExtractor,FeatureTransformLayer,AudioAttentionLayers 9 | # Ensure your inference device is set correctly 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | # invoke using 14 | # python ./junk/AudioAttention/inference.py 15 | 16 | # Load the trained model and other necessary modules for inference 17 | def load_model(checkpoint_path, device): 18 | # Load pre-trained ResNet-50 and Wav2Vec2.0 models for feature extraction 19 | resnet_model = models.resnet50(pretrained=True) 20 | resnet_model.eval() 21 | resnet_model.to(device) 22 | 23 | wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") 24 | wav2vec_model.eval() 25 | wav2vec_model.to(device) 26 | 27 | # Load your custom AudioAttention model 28 | audio_attention = AudioAttentionLayers(feature_dim=768, num_layers=3, device=device) 29 | audio_attention.to(device) 30 | 31 | # Load the checkpoint 32 | checkpoint = torch.load(checkpoint_path, map_location=device) 33 | audio_attention.load_state_dict(checkpoint['model_state_dict']) 34 | 35 | return resnet_model, wav2vec_model, audio_attention 36 | 37 | # Define the preprocessing for image and audio 38 | image_transforms = transforms.Compose([ 39 | transforms.Resize((256, 256)), 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 42 | ]) 43 | 44 | audio_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") 45 | 46 | # Assuming ResNet-50 feature maps are 2048 in depth before the final pooling and classification layers. 
47 | visual_feature_transform = FeatureTransformLayer(input_dim=2048, output_dim=768, device=device) 48 | 49 | # Inference function 50 | def inference(image_path, audio_path, resnet_model, wav2vec_model, audio_attention, device): 51 | # Process the image 52 | image = Image.open(image_path).convert('RGB') 53 | image_tensor = image_transforms(image).unsqueeze(0).to(device) 54 | 55 | 56 | # Extract features from the image using ResNet 57 | with torch.no_grad(): 58 | visual_features = resnet_feature_extractor(image_tensor) 59 | # Flatten the spatial dimensions: 60 | visual_features = visual_features.view(visual_features.size(0), 2048, -1).mean(-1) 61 | visual_features = visual_feature_transform(visual_features) 62 | 63 | # Load and process the audio 64 | waveform, sample_rate = torchaudio.load(audio_path) 65 | waveform = waveform.to(device) 66 | 67 | # If the waveform tensor has more than 2 dimensions (batch and channels), reduce it to 2 dimensions 68 | if waveform.ndim > 2: 69 | waveform = waveform.squeeze(0) # Remove the batch dimension if it's present 70 | 71 | # Process the audio to extract features 72 | with torch.no_grad(): 73 | input_values = audio_processor(waveform, sampling_rate=sample_rate, return_tensors="pt").input_values 74 | # Check the dimension of input_values and remove any unnecessary dimensions 75 | if input_values.ndim > 2: 76 | input_values = input_values.squeeze(0) # Squeeze out the batch dimension if it's 1 77 | 78 | # Ensure input_values is 2D (batch, sequence_length) 79 | assert input_values.ndim == 2, f"Input_values should be 2D, but got {input_values.size()}" 80 | 81 | input_values = input_values.to(device) 82 | audio_features = wav2vec_model(input_values).last_hidden_state 83 | audio_features = audio_features.mean(dim=1) 84 | 85 | # Use the AudioAttention model to generate attended features and attention weights 86 | with torch.no_grad(): 87 | attended_features, attention_weights_list = audio_attention( 88 | visual_features, 89 | audio_features, 90 | return_attention_weights=True 91 | ) 92 | 93 | # Convert list of attention weights to a tensor if necessary 94 | if isinstance(attention_weights_list, list): 95 | # Assuming attention_weights_list is a list of tensors with the same shape 96 | attention_weights_tensor = torch.stack(attention_weights_list) 97 | else: 98 | attention_weights_tensor = attention_weights_list # if it's already a tensor 99 | 100 | # Now you can index attention_weights_tensor as needed 101 | # Ensure the indexing matches the dimensions of the tensor 102 | attention_matrix = attention_weights_tensor[0, 0, :].detach().cpu().numpy() 103 | 104 | return attended_features, attention_matrix 105 | 106 | 107 | # Load the model (make sure to replace 'checkpoint_epoch_10.pt' with your actual checkpoint file) 108 | resnet_model, wav2vec_model, audio_attention = load_model('./checkpoints/checkpoint_epoch_10.pt', device) 109 | 110 | # After loading the ResNet model: 111 | resnet_feature_extractor = torch.nn.Sequential(*list(resnet_model.children())[:-2]) 112 | resnet_feature_extractor.to(device) 113 | 114 | # Perform inference (replace 'path_to_image.jpg' and 'path_to_audio.wav' with your actual file paths) 115 | attended_features,attention_weights = inference('./junk/AudioAttention/synthetic_dataset/beep_0.png', './junk/AudioAttention/synthetic_dataset/beep_0.wav.wav', resnet_model, wav2vec_model, audio_attention, device) 116 | 117 | # Here, 'attended_features' would be the output of your model 118 | # print("Attended features:", attended_features) 119 | # 
Assuming attention_weights is a list of tensors 120 | if attention_weights: 121 | # Access the first layer's weights (check the dimensions to be sure) 122 | first_layer_weights = attention_weights[0] 123 | 124 | if first_layer_weights.ndim == 4: 125 | # Access the first head's attention weights of the first layer 126 | attention_matrix = first_layer_weights[0, 0, :, :].detach().cpu().numpy() 127 | else: 128 | raise ValueError("Unexpected number of dimensions in attention weights") 129 | 130 | 131 | # Plotting as a heatmap 132 | sns.heatmap(attention_matrix.reshape(-1, 1), cmap='viridis', cbar=True) 133 | plt.title('Attention Weights Over Audio Sequence') 134 | plt.xlabel('Attention Head') 135 | plt.ylabel('Audio Time Steps') 136 | plt.show() -------------------------------------------------------------------------------- /junk/AudioAttention/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from transformers import Wav2Vec2Processor, Wav2Vec2Model 7 | import soundfile as sf 8 | import numpy as np 9 | import torchaudio 10 | from dataset import AudioVisualDataset 11 | import torch 12 | from torchvision import models, transforms 13 | from PIL import Image 14 | import logging 15 | import os 16 | 17 | 18 | 19 | class Wav2VecFeatureExtractor: 20 | def __init__(self, model_name='facebook/wav2vec2-base-960h', device='cpu'): 21 | self.model_name = model_name 22 | self.device = device 23 | self.processor = Wav2Vec2Processor.from_pretrained(model_name) 24 | self.model = Wav2Vec2Model.from_pretrained(model_name).to(device) 25 | 26 | def extract_features(self, waveform, sample_rate=16000): 27 | """ 28 | Extract audio features from a waveform using Wav2Vec 2.0. 29 | 30 | Args: 31 | waveform (Tensor): The waveform of the audio. 32 | sample_rate (int): The sample rate of the waveform. 33 | 34 | Returns: 35 | torch.Tensor: Features extracted from the audio. 36 | """ 37 | 38 | # Ensure waveform is a 2D tensor (channel, time) 39 | if waveform.ndim == 1: 40 | waveform = waveform.unsqueeze(0) # Add channel dimension if it's not present 41 | elif waveform.ndim == 3: 42 | # If there is a batch dimension, we squeeze it out, because the model expects 2D input 43 | # This is true for a batch size of 1. If the batch size is greater, further handling is needed. 
44 | waveform = waveform.squeeze(0) 45 | 46 | # Process the audio to extract features 47 | input_values = self.processor(waveform, sampling_rate=sample_rate, return_tensors="pt").input_values 48 | 49 | 50 | # Reshape input_values to remove the extraneous dimension if present 51 | if input_values.ndim == 3 and input_values.size(0) == 1: 52 | input_values = input_values.squeeze(0) 53 | 54 | # Check the shape of input_values to be sure it's 2D now 55 | assert input_values.ndim == 2, "Input_values should be 2D (batch, sequence_length) after squeezing" 56 | 57 | input_values = input_values.to(self.device) 58 | 59 | # Pass the input_values to the model 60 | with torch.no_grad(): 61 | features = self.model(input_values).last_hidden_state 62 | 63 | # Reduce the dimensions if necessary, usually you get (batch, seq_length, features) 64 | # You might want to average the sequence length dimension or handle it appropriately 65 | features = features.mean(dim=1) # Example of reducing the sequence length dimension by averaging 66 | return features 67 | 68 | 69 | 70 | 71 | # Implement the CrossAttentionLayer and AudioAttentionLayers 72 | class CrossAttentionLayer(nn.Module): 73 | def __init__(self, feature_dim, device): 74 | super(CrossAttentionLayer, self).__init__() 75 | self.query = nn.Linear(feature_dim, feature_dim).to(device) 76 | self.key = nn.Linear(feature_dim, feature_dim).to(device) 77 | self.value = nn.Linear(feature_dim, feature_dim).to(device) 78 | # Initialize scale on the correct device 79 | self.scale = torch.sqrt(torch.FloatTensor([feature_dim])).to(device) 80 | 81 | def forward(self, latent_code, audio_features, return_attention_weights=True): 82 | query = self.query(latent_code) 83 | key = self.key(audio_features) 84 | value = self.value(audio_features) 85 | 86 | attention_scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale 87 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 88 | attention_output = torch.matmul(attention_probs, value) 89 | 90 | if return_attention_weights: 91 | return attention_output, attention_probs # Return both output and attention weights 92 | return attention_output, None # If not returning weights, return None in place of weights 93 | 94 | 95 | 96 | class AudioAttentionLayers(nn.Module): 97 | def __init__(self, feature_dim, num_layers, device): 98 | super(AudioAttentionLayers, self).__init__() 99 | self.layers = nn.ModuleList([CrossAttentionLayer(feature_dim, device) for _ in range(num_layers)]) 100 | 101 | def forward(self, latent_code, audio_features, return_attention_weights=False): 102 | attention_weights = [] 103 | for layer in self.layers: 104 | output, weights = layer(latent_code, audio_features, return_attention_weights=True) 105 | latent_code = output + latent_code 106 | if return_attention_weights: 107 | attention_weights.append(weights) 108 | 109 | if return_attention_weights: 110 | # Stack the weights from each layer to form a tensor 111 | return latent_code, torch.stack(attention_weights) 112 | return latent_code 113 | 114 | class FeatureTransformLayer(nn.Module): 115 | def __init__(self, input_dim, output_dim,device): 116 | super(FeatureTransformLayer, self).__init__() 117 | self.transform = nn.Linear(input_dim, output_dim).to(device) 118 | 119 | def forward(self, features): 120 | # print("FeatureTransformLayer input device:", features.device) # Log device 121 | return self.transform(features) 122 | -------------------------------------------------------------------------------- /junk/AudioAttention/synthesize.py: 
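Note: the snippet below is a minimal, illustrative usage sketch for the attention stack defined in model.py above; it is not part of the repository. The dummy tensors and dimensions are assumptions that mirror how train.py and inference.py wire these classes together (2048-d pooled ResNet-50 features projected to 768-d to match wav2vec2-base), and it assumes it is run from junk/AudioAttention/ so that "from model import ..." resolves.

import torch
from model import Wav2VecFeatureExtractor, FeatureTransformLayer, AudioAttentionLayers

device = "cuda" if torch.cuda.is_available() else "cpu"

# Project pooled ResNet-50 features (2048-d) into the 768-d space used by wav2vec2-base.
transform = FeatureTransformLayer(input_dim=2048, output_dim=768, device=device)
attention = AudioAttentionLayers(feature_dim=768, num_layers=3, device=device)
extractor = Wav2VecFeatureExtractor('facebook/wav2vec2-base-960h', device=device)

visual = transform(torch.randn(1, 2048, device=device))    # stand-in for pooled ResNet features -> (1, 768)
audio = extractor.extract_features(torch.randn(1, 16000))  # 1 s of 16 kHz mono noise -> (1, 768)
attended, weights = attention(visual, audio, return_attention_weights=True)
print(attended.shape, weights.shape)  # torch.Size([1, 768]) torch.Size([3, 1, 1])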
-------------------------------------------------------------------------------- 1 | import os 2 | import torchvision 3 | import torch 4 | 5 | # pip3 install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 6 | from diffusers import StableDiffusionPipeline 7 | from audiocraft.models import AudioGen 8 | from audiocraft.data.audio import audio_write 9 | 10 | # Set up the AudioGen model 11 | model = AudioGen.get_pretrained('facebook/audiogen-medium') 12 | model.set_generation_params(duration=1) # Generate 1-second audio samples 13 | 14 | # Set up the Stable Diffusion model 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=True) 17 | pipe = pipe.to(device) 18 | 19 | # Define the audio descriptions and corresponding image descriptions 20 | descriptions = ['Aah', 'Ooh', 'Mmm'] 21 | image_descriptions = ['A person with their mouth wide open, as if saying "Aah"', 'A person with pursed lips, as if saying "Ooh"', 'A person with a slight smile and closed eyes, as if savoring something and saying "Mmm"'] 22 | 23 | 24 | 25 | 26 | # A person with a confused expression, raised eyebrows, and slightly open mouth, as if saying "Huh" 27 | # A person with a disgusted expression, furrowed brows, and a frown, as if saying "Ugh" 28 | # A person with wide eyes, raised eyebrows, and an open mouth, as if expressing surprise and saying "Wow" 29 | # A person with a finger to their lips, as if saying "Shh" 30 | # A person wiping their brow with a relieved expression, as if saying "Phew" 31 | # A person with a touched or sympathetic expression, as if saying "Aww" 32 | # A person with a disapproving expression, clicking their tongue, as if saying "Tsk" 33 | # A person shivering with a cold expression, as if saying "Brrr" 34 | 35 | # Set up the output directory 36 | output_dir = 'synthetic_dataset' 37 | os.makedirs(output_dir, exist_ok=True) 38 | 39 | # Generate synthetic audio-visual pairs 40 | num_samples = 10 # Number of samples to generate for each description pair 41 | for desc, img_desc in zip(descriptions, image_descriptions): 42 | for i in range(num_samples): 43 | # Generate audio 44 | wav = model.generate([desc]) 45 | audio_filename = f"{desc}_{i}" 46 | audio_path = os.path.join(output_dir, audio_filename) 47 | audio_write(audio_path, wav[0].cpu(), model.sample_rate, strategy="loudness") 48 | 49 | # Generate corresponding image using Stable Diffusion 50 | image_filename = f"{desc}_{i}.png" 51 | image_path = os.path.join(output_dir, image_filename) 52 | with torch.autocast("cuda"): 53 | image = pipe(img_desc).images[0] 54 | image.save(image_path) 55 | 56 | print(f"Generated {audio_filename} and {image_filename}") 57 | 58 | print("Synthetic dataset generation completed.") -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_0.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_0.wav.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_0.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_1.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_1.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_1.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_2.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_2.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_2.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_3.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_3.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_3.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_4.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_4.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_4.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_5.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_5.wav.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_5.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_6.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_6.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_6.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_7.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_7.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_7.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_8.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_8.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_8.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_9.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/beep_9.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/beep_9.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_0.png 
-------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_0.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_0.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_1.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_1.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_1.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_2.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_2.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_2.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_3.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_3.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_3.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_4.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_4.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_4.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_5.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_5.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_5.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_6.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_6.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_6.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_7.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_7.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_7.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_8.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_8.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_8.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_9.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/buzz_9.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/buzz_9.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_0.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_0.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_0.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_1.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_1.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_1.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_2.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_2.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_2.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_3.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_3.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_3.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_4.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_4.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_4.wav.wav 
-------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_5.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_5.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_5.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_6.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_6.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_6.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_7.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_7.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_7.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_8.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_8.wav.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_8.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_9.png -------------------------------------------------------------------------------- /junk/AudioAttention/synthetic_dataset/tick_9.wav.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/AudioAttention/synthetic_dataset/tick_9.wav.wav -------------------------------------------------------------------------------- /junk/AudioAttention/train.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from transformers import Wav2Vec2Processor, Wav2Vec2Model 7 | import soundfile as sf 8 | import numpy as np 9 | import torchaudio 10 | from dataset import AudioVisualDataset 11 | import torch 12 | from torchvision import models, transforms 13 | from PIL import Image 14 | import logging 15 | import os 16 | from model import Wav2VecFeatureExtractor,FeatureTransformLayer,AudioAttentionLayers 17 | 18 | # Load a pre-trained ResNet-50 model 19 | model = models.resnet50(pretrained=True) 20 | model.eval() # Set the model to inference mode 21 | 22 | 23 | # run this like 24 | # python ./junk/AudioAttention/train.py 25 | 26 | 27 | # Setup logging 28 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 29 | logger = logging.getLogger() 30 | 31 | # Checkpoint directory 32 | checkpoint_dir = './checkpoints' 33 | os.makedirs(checkpoint_dir, exist_ok=True) 34 | 35 | 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | # Instantiate a transformation layer for visual features 39 | visual_feature_transform = FeatureTransformLayer(input_dim=2048, output_dim=768,device=device) 40 | # visual_feature_transform = visual_feature_transform.to(device) # Move to GPU if necessary 41 | # 42 | 43 | # Instantiate the Wav2VecFeatureExtractor 44 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") 45 | model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") 46 | 47 | # Instantiate the AudioAttentionLayers 48 | audio_attention = AudioAttentionLayers(feature_dim=768, num_layers=3, device=device) 49 | 50 | 51 | 52 | # Pre-load and set the ResNet-50 model to inference mode 53 | resnet_model = models.resnet50(pretrained=True) 54 | resnet_model.eval() 55 | resnet_model.to(device) # Assuming using CUDA 56 | 57 | # Remove the final fully connected layer to extract features instead of class predictions 58 | resnet_feature_extractor = torch.nn.Sequential(*(list(resnet_model.children())[:-1])) 59 | resnet_feature_extractor.to(device) 60 | 61 | # Instantiate the dataset and data loader 62 | dataset = AudioVisualDataset('./junk/AudioAttention/synthetic_dataset') # launching with vscode / settings 63 | data_loader = DataLoader(dataset, batch_size=32, shuffle=True) 64 | 65 | # Set up the optimizer and loss function 66 | optimizer = optim.Adam(audio_attention.parameters(), lr=0.001) 67 | criterion = nn.MSELoss() 68 | 69 | 70 | # Instantiate the Wav2VecFeatureExtractor 71 | audio_feature_extractor = Wav2VecFeatureExtractor('facebook/wav2vec2-base-960h',device) 72 | 73 | # Training loop 74 | num_epochs = 10 75 | 76 | 77 | 78 | # Initialize metric accumulators 79 | total_loss = 0.0 80 | total_samples = 0 81 | log_interval = 10 82 | for epoch in range(num_epochs): 83 | epoch_loss = 0.0 84 | for batch_idx, batch in enumerate(data_loader): 85 | # Check that 'audio' and 'image' keys exist in the batch 86 | assert 'audio' in batch and 'image' in batch, "Batch must contain 'audio' and 'image' keys" 87 | 88 | # Assuming 'audio' and 'image' are correctly loaded and 'audio' is 
waveform 89 | waveform = batch['audio'] # Assuming sample_rate is 16000 for all waveforms 90 | waveform = waveform.to(device) 91 | # Assuming 'waveform' is a 3D tensor from the DataLoader 92 | for i in range(waveform.size(0)): # Iterate over the batch 93 | single_waveform = waveform[i] # Extract a 2D tensor (channels, time) 94 | # Process single_waveform as needed 95 | 96 | # Extract audio features 97 | audio_features = audio_feature_extractor.extract_features(single_waveform) 98 | 99 | # Assuming batch["image"] is a tensor of shape (batch_size, C, H, W) 100 | image = batch["image"].to(device) 101 | 102 | 103 | # Ensure the waveform and image tensors are not empty and have expected dimensions 104 | assert waveform.ndim == 3, "Waveform tensor should have 3 dimensions (batch, channels, length)" 105 | assert image.ndim == 4, "Image tensor should have 4 dimensions (batch, channels, height, width)" 106 | 107 | # Preprocess and extract features for the whole batch with torch.no_grad(): 108 | visual_features = resnet_feature_extractor(image) 109 | 110 | 111 | # Depending on the resnet_feature_extractor output, you might need to adjust its dimensions 112 | # If the feature extractor outputs a tensor of shape (batch_size, features, 1, 1), 113 | # you should remove the last two dimensions: 114 | visual_features = torch.flatten(visual_features, start_dim=1) 115 | 116 | # Move visual features to the same device as the transformation layer 117 | visual_features = visual_features.to(device) 118 | # Transform visual features to match the attention layer input dimensionality 119 | visual_features_transformed = visual_feature_transform(visual_features) 120 | 121 | # Assert that features have the correct dimensions 122 | # print("🎉 audio_features:",audio_features.ndim) 123 | assert audio_features.ndim == 2, "Audio features should be 2D (batch, features)" 124 | assert visual_features_transformed.ndim == 2, "Visual features should be 2D (batch, features)" 125 | 126 | optimizer.zero_grad() 127 | attended_features,attention_weights = audio_attention( 128 | visual_features_transformed, 129 | audio_features, 130 | return_attention_weights=True 131 | ) 132 | # Assert the attended features have the correct shape 133 | assert attended_features.ndim == 2, "Attended features should be 2D (batch, features)" 134 | 135 | 136 | loss = criterion(attended_features, visual_features_transformed) 137 | loss.backward() 138 | optimizer.step() 139 | 140 | total_loss += loss.item() * waveform.size(0) 141 | total_samples += waveform.size(0) 142 | # Accumulate loss for the epoch 143 | epoch_loss += loss.item() 144 | 145 | # Logging within the batch loop if needed 146 | if batch_idx % log_interval == 0: # Assuming you define log_interval 147 | logger.info(f'Epoch: {epoch+1} [{batch_idx * len(waveform)}/{len(data_loader.dataset)} ' 148 | f'({100. 
* batch_idx / len(data_loader):.0f}%)]\tLoss: {loss.item():.6f}') 149 | 150 | 151 | print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}") 152 | 153 | # Evaluation and visualization 154 | # Calculate the average loss across all validation samples 155 | average_loss = total_loss / total_samples 156 | 157 | 158 | 159 | # Comparison with baselines 160 | # Here, you should load the performance metrics of baseline models or previous studies 161 | # For demonstration, let's assume you have these as constants 162 | baseline_loss = 0.05 # hypothetical value 163 | 164 | # Average loss for the epoch 165 | epoch_loss /= len(data_loader.dataset) 166 | logger.info(f'Epoch {epoch+1} finished, average loss: {epoch_loss:.6f}') 167 | 168 | # Saving checkpoints after the epoch 169 | checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt') 170 | torch.save({ 171 | 'epoch': epoch+1, 172 | 'model_state_dict': audio_attention.state_dict(), 173 | 'optimizer_state_dict': optimizer.state_dict(), 174 | 'loss': epoch_loss, 175 | }, checkpoint_path) 176 | 177 | logger.info(f'Checkpoint saved to {checkpoint_path}') 178 | 179 | # Rest of your code... 180 | 181 | # Instead of printing, use logger to log the final results 182 | logger.info(f'Average validation loss: {average_loss:.4f}') 183 | logger.info(f'Comparison with Baseline: Baseline Loss - {baseline_loss:.4f}, Model Loss - {average_loss:.4f}') -------------------------------------------------------------------------------- /junk/BroadcastinExample.py: -------------------------------------------------------------------------------- 1 | from manim import Scene,Text,Write,ReplacementTransform,FadeOut,UP,Matrix 2 | import numpy as np 3 | 4 | 5 | 6 | # 7 | # sudo apt-get install texlive-latex-extra # 300MB 8 | 9 | 10 | # manim -pql BroadcastinExample.py BroadcastingExample -r 1280,720 11 | class BroadcastingExample(Scene): 12 | def construct(self): 13 | # Create tensors 14 | latent_codes = np.random.randn(4, 512, 10) 15 | speed_embeddings = np.random.randn(64) 16 | 17 | # Create Manim arrays 18 | latent_codes_array = Matrix(latent_codes.astype(int), v_buff=0.5, h_buff=1) 19 | speed_embeddings_array = Matrix(speed_embeddings.astype(int), v_buff=0.5) 20 | 21 | # Create text labels 22 | latent_codes_text = Text("Latent Codes").scale(0.7).next_to(latent_codes_array, UP) 23 | speed_embeddings_text = Text("Speed Embeddings").scale(0.7).next_to(speed_embeddings_array, UP) 24 | 25 | # Animate the creation of arrays and labels 26 | self.play( 27 | Write(latent_codes_array), 28 | Write(speed_embeddings_array), 29 | Write(latent_codes_text), 30 | Write(speed_embeddings_text), 31 | ) 32 | self.wait(2) 33 | 34 | # Expand the speed embeddings 35 | expanded_speed_embeddings = np.expand_dims(np.expand_dims(speed_embeddings, axis=0), axis=-1) 36 | expanded_speed_embeddings = np.tile(expanded_speed_embeddings, (4, 1, 10)) 37 | expanded_speed_embeddings_array = Matrix(expanded_speed_embeddings.astype(int), v_buff=0.5, h_buff=1) 38 | expanded_speed_embeddings_text = Text("Expanded Speed Embeddings").scale(0.7).next_to(expanded_speed_embeddings_array, UP) 39 | 40 | # Animate the expansion of speed embeddings 41 | self.play( 42 | ReplacementTransform(speed_embeddings_array, expanded_speed_embeddings_array), 43 | ReplacementTransform(speed_embeddings_text, expanded_speed_embeddings_text), 44 | ) 45 | self.wait(2) 46 | 47 | # Perform broadcasting and element-wise addition 48 | combined_features = latent_codes + expanded_speed_embeddings 49 | 
combined_features_array = Matrix(combined_features.astype(int), v_buff=0.5, h_buff=1) 50 | combined_features_text = Text("Combined Features").scale(0.7).next_to(combined_features_array, UP) 51 | 52 | # Animate the broadcasting and addition 53 | self.play( 54 | ReplacementTransform(latent_codes_array, combined_features_array), 55 | ReplacementTransform(expanded_speed_embeddings_array, combined_features_array), 56 | ReplacementTransform(latent_codes_text, combined_features_text), 57 | FadeOut(expanded_speed_embeddings_text), 58 | ) 59 | self.wait(2) 60 | 61 | # Show the final result 62 | self.play(combined_features_array.animate.scale(1.2)) 63 | self.wait(3) -------------------------------------------------------------------------------- /junk/M2Ohb0FAaJU_1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/M2Ohb0FAaJU_1.mp4 -------------------------------------------------------------------------------- /junk/M2Ohb0FAaJU_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/M2Ohb0FAaJU_1.wav -------------------------------------------------------------------------------- /junk/bla.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/bla.png -------------------------------------------------------------------------------- /junk/frame_0094_debug.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/frame_0094_debug.jpg -------------------------------------------------------------------------------- /junk/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johndpope/Emote-hack/f4e66ac6ba4f6bd95584b856d135f3362a0d86eb/junk/pipeline.png -------------------------------------------------------------------------------- /magicanimate/models/attention.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
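# Overview: Transformer3DModel applies a stack of BasicTransformerBlock layers to 5-D video
# latents (batch, channels, frames, height, width), folding the frame axis into the batch so
# each frame is handled by 2-D spatial attention. Each BasicTransformerBlock combines
# self-attention (optionally SparseCausalAttention2D across frames), optional cross-attention
# to encoder hidden states, a feed-forward layer, and an optional temporal attention whose
# output projection is zero-initialized.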
20 | from dataclasses import dataclass 21 | from typing import Optional 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | from torch import nn 26 | 27 | from diffusers.configuration_utils import ConfigMixin, register_to_config 28 | from diffusers.models.modeling_utils import ModelMixin 29 | from diffusers.utils import BaseOutput 30 | from diffusers.utils.import_utils import is_xformers_available 31 | from diffusers.models.attention import FeedForward, AdaLayerNorm 32 | from diffusers.models.attention import Attention as CrossAttention 33 | 34 | from einops import rearrange, repeat 35 | 36 | @dataclass 37 | class Transformer3DModelOutput(BaseOutput): 38 | sample: torch.FloatTensor 39 | 40 | 41 | if is_xformers_available(): 42 | import xformers 43 | import xformers.ops 44 | else: 45 | xformers = None 46 | 47 | 48 | class Transformer3DModel(ModelMixin, ConfigMixin): 49 | @register_to_config 50 | def __init__( 51 | self, 52 | num_attention_heads: int = 16, 53 | attention_head_dim: int = 88, 54 | in_channels: Optional[int] = None, 55 | num_layers: int = 1, 56 | dropout: float = 0.0, 57 | norm_num_groups: int = 32, 58 | cross_attention_dim: Optional[int] = None, 59 | attention_bias: bool = False, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | use_linear_projection: bool = False, 63 | only_cross_attention: bool = False, 64 | upcast_attention: bool = False, 65 | 66 | unet_use_cross_frame_attention=None, 67 | unet_use_temporal_attention=None, 68 | ): 69 | super().__init__() 70 | self.use_linear_projection = use_linear_projection 71 | self.num_attention_heads = num_attention_heads 72 | self.attention_head_dim = attention_head_dim 73 | inner_dim = num_attention_heads * attention_head_dim 74 | 75 | # Define input layers 76 | self.in_channels = in_channels 77 | 78 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 79 | if use_linear_projection: 80 | self.proj_in = nn.Linear(in_channels, inner_dim) 81 | else: 82 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 83 | 84 | # Define transformers blocks 85 | self.transformer_blocks = nn.ModuleList( 86 | [ 87 | BasicTransformerBlock( 88 | inner_dim, 89 | num_attention_heads, 90 | attention_head_dim, 91 | dropout=dropout, 92 | cross_attention_dim=cross_attention_dim, 93 | activation_fn=activation_fn, 94 | num_embeds_ada_norm=num_embeds_ada_norm, 95 | attention_bias=attention_bias, 96 | only_cross_attention=only_cross_attention, 97 | upcast_attention=upcast_attention, 98 | 99 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 100 | unet_use_temporal_attention=unet_use_temporal_attention, 101 | ) 102 | for d in range(num_layers) 103 | ] 104 | ) 105 | 106 | # 4. Define output layers 107 | if use_linear_projection: 108 | self.proj_out = nn.Linear(in_channels, inner_dim) 109 | else: 110 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 111 | 112 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 113 | # Input 114 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
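# hidden_states: (batch, channels, frames, height, width). The frame axis is folded into the
# batch below so the 2-D transformer blocks run once per frame; encoder_hidden_states is
# repeated per frame when needed, and the output is unfolded back to 5-D before returning.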
115 | video_length = hidden_states.shape[2] 116 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 117 | # JH: need not repeat when a list of prompts are given 118 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 119 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 120 | 121 | batch, channel, height, weight = hidden_states.shape 122 | residual = hidden_states 123 | 124 | hidden_states = self.norm(hidden_states) 125 | if not self.use_linear_projection: 126 | hidden_states = self.proj_in(hidden_states) 127 | inner_dim = hidden_states.shape[1] 128 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 129 | else: 130 | inner_dim = hidden_states.shape[1] 131 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 132 | hidden_states = self.proj_in(hidden_states) 133 | 134 | # Blocks 135 | for block in self.transformer_blocks: 136 | hidden_states = block( 137 | hidden_states, 138 | encoder_hidden_states=encoder_hidden_states, 139 | timestep=timestep, 140 | video_length=video_length 141 | ) 142 | 143 | # Output 144 | if not self.use_linear_projection: 145 | hidden_states = ( 146 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 147 | ) 148 | hidden_states = self.proj_out(hidden_states) 149 | else: 150 | hidden_states = self.proj_out(hidden_states) 151 | hidden_states = ( 152 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 153 | ) 154 | 155 | output = hidden_states + residual 156 | 157 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 158 | if not return_dict: 159 | return (output,) 160 | 161 | return Transformer3DModelOutput(sample=output) 162 | 163 | 164 | class BasicTransformerBlock(nn.Module): 165 | def __init__( 166 | self, 167 | dim: int, 168 | num_attention_heads: int, 169 | attention_head_dim: int, 170 | dropout=0.0, 171 | cross_attention_dim: Optional[int] = None, 172 | activation_fn: str = "geglu", 173 | num_embeds_ada_norm: Optional[int] = None, 174 | attention_bias: bool = False, 175 | only_cross_attention: bool = False, 176 | upcast_attention: bool = False, 177 | 178 | unet_use_cross_frame_attention = None, 179 | unet_use_temporal_attention = None, 180 | ): 181 | super().__init__() 182 | self.only_cross_attention = only_cross_attention 183 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 184 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 185 | self.unet_use_temporal_attention = unet_use_temporal_attention 186 | 187 | # SC-Attn 188 | assert unet_use_cross_frame_attention is not None 189 | if unet_use_cross_frame_attention: 190 | self.attn1 = SparseCausalAttention2D( 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 197 | upcast_attention=upcast_attention, 198 | ) 199 | else: 200 | self.attn1 = CrossAttention( 201 | query_dim=dim, 202 | heads=num_attention_heads, 203 | dim_head=attention_head_dim, 204 | dropout=dropout, 205 | bias=attention_bias, 206 | upcast_attention=upcast_attention, 207 | ) 208 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 209 | 210 | # Cross-Attn 211 | if cross_attention_dim is not None: 212 | self.attn2 = CrossAttention( 213 | query_dim=dim, 214 
| cross_attention_dim=cross_attention_dim, 215 | heads=num_attention_heads, 216 | dim_head=attention_head_dim, 217 | dropout=dropout, 218 | bias=attention_bias, 219 | upcast_attention=upcast_attention, 220 | ) 221 | else: 222 | self.attn2 = None 223 | 224 | if cross_attention_dim is not None: 225 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 226 | else: 227 | self.norm2 = None 228 | 229 | # Feed-forward 230 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 231 | self.norm3 = nn.LayerNorm(dim) 232 | self.use_ada_layer_norm_zero = False 233 | 234 | # Temp-Attn 235 | assert unet_use_temporal_attention is not None 236 | if unet_use_temporal_attention: 237 | self.attn_temp = CrossAttention( 238 | query_dim=dim, 239 | heads=num_attention_heads, 240 | dim_head=attention_head_dim, 241 | dropout=dropout, 242 | bias=attention_bias, 243 | upcast_attention=upcast_attention, 244 | ) 245 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 246 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 247 | 248 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): 249 | if not is_xformers_available(): 250 | print("Here is how to install it") 251 | raise ModuleNotFoundError( 252 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 253 | " xformers", 254 | name="xformers", 255 | ) 256 | elif not torch.cuda.is_available(): 257 | raise ValueError( 258 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 259 | " available for GPU " 260 | ) 261 | else: 262 | try: 263 | # Make sure we can run the memory efficient attention 264 | _ = xformers.ops.memory_efficient_attention( 265 | torch.randn((1, 2, 40), device="cuda"), 266 | torch.randn((1, 2, 40), device="cuda"), 267 | torch.randn((1, 2, 40), device="cuda"), 268 | ) 269 | except Exception as e: 270 | raise e 271 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 272 | if self.attn2 is not None: 273 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 274 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 275 | 276 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 277 | # SparseCausal-Attention 278 | norm_hidden_states = ( 279 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 280 | ) 281 | 282 | # if self.only_cross_attention: 283 | # hidden_states = ( 284 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 285 | # ) 286 | # else: 287 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 288 | 289 | # pdb.set_trace() 290 | if self.unet_use_cross_frame_attention: 291 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 292 | else: 293 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 294 | 295 | if self.attn2 is not None: 296 | # Cross-Attention 297 | norm_hidden_states = ( 298 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 299 | ) 300 | 
hidden_states = ( 301 | self.attn2( 302 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 303 | ) 304 | + hidden_states 305 | ) 306 | 307 | # Feed-forward 308 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 309 | 310 | # Temporal-Attention 311 | if self.unet_use_temporal_attention: 312 | d = hidden_states.shape[1] 313 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 314 | norm_hidden_states = ( 315 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 316 | ) 317 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 318 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 319 | 320 | return hidden_states 321 | -------------------------------------------------------------------------------- /magicanimate/models/embeddings.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import math 21 | from typing import Optional 22 | 23 | import numpy as np 24 | import torch 25 | from torch import nn 26 | 27 | 28 | def get_timestep_embedding( 29 | timesteps: torch.Tensor, 30 | embedding_dim: int, 31 | flip_sin_to_cos: bool = False, 32 | downscale_freq_shift: float = 1, 33 | scale: float = 1, 34 | max_period: int = 10000, 35 | ): 36 | """ 37 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 38 | 39 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 40 | These may be fractional. 41 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 42 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
43 | """ 44 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 45 | 46 | half_dim = embedding_dim // 2 47 | exponent = -math.log(max_period) * torch.arange( 48 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 49 | ) 50 | exponent = exponent / (half_dim - downscale_freq_shift) 51 | 52 | emb = torch.exp(exponent) 53 | emb = timesteps[:, None].float() * emb[None, :] 54 | 55 | # scale embeddings 56 | emb = scale * emb 57 | 58 | # concat sine and cosine embeddings 59 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 60 | 61 | # flip sine and cosine embeddings 62 | if flip_sin_to_cos: 63 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 64 | 65 | # zero pad 66 | if embedding_dim % 2 == 1: 67 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 68 | return emb 69 | 70 | 71 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 72 | """ 73 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 74 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 75 | """ 76 | grid_h = np.arange(grid_size, dtype=np.float32) 77 | grid_w = np.arange(grid_size, dtype=np.float32) 78 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 79 | grid = np.stack(grid, axis=0) 80 | 81 | grid = grid.reshape([2, 1, grid_size, grid_size]) 82 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 83 | if cls_token and extra_tokens > 0: 84 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 85 | return pos_embed 86 | 87 | 88 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 89 | if embed_dim % 2 != 0: 90 | raise ValueError("embed_dim must be divisible by 2") 91 | 92 | # use half of dimensions to encode grid_h 93 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 94 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 95 | 96 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 97 | return emb 98 | 99 | 100 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 101 | """ 102 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 103 | """ 104 | if embed_dim % 2 != 0: 105 | raise ValueError("embed_dim must be divisible by 2") 106 | 107 | omega = np.arange(embed_dim // 2, dtype=np.float64) 108 | omega /= embed_dim / 2.0 109 | omega = 1.0 / 10000**omega # (D/2,) 110 | 111 | pos = pos.reshape(-1) # (M,) 112 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 113 | 114 | emb_sin = np.sin(out) # (M, D/2) 115 | emb_cos = np.cos(out) # (M, D/2) 116 | 117 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 118 | return emb 119 | 120 | 121 | class PatchEmbed(nn.Module): 122 | """2D Image to Patch Embedding""" 123 | 124 | def __init__( 125 | self, 126 | height=224, 127 | width=224, 128 | patch_size=16, 129 | in_channels=3, 130 | embed_dim=768, 131 | layer_norm=False, 132 | flatten=True, 133 | bias=True, 134 | ): 135 | super().__init__() 136 | 137 | num_patches = (height // patch_size) * (width // patch_size) 138 | self.flatten = flatten 139 | self.layer_norm = layer_norm 140 | 141 | self.proj = nn.Conv2d( 142 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 143 | ) 144 | if layer_norm: 145 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 146 | else: 147 | self.norm = None 148 | 149 | pos_embed = get_2d_sincos_pos_embed(embed_dim, 
int(num_patches**0.5)) 150 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 151 | 152 | def forward(self, latent): 153 | latent = self.proj(latent) 154 | if self.flatten: 155 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 156 | if self.layer_norm: 157 | latent = self.norm(latent) 158 | return latent + self.pos_embed 159 | 160 | 161 | class TimestepEmbedding(nn.Module): 162 | def __init__( 163 | self, 164 | in_channels: int, 165 | time_embed_dim: int, 166 | act_fn: str = "silu", 167 | out_dim: int = None, 168 | post_act_fn: Optional[str] = None, 169 | cond_proj_dim=None, 170 | ): 171 | super().__init__() 172 | 173 | self.linear_1 = nn.Linear(in_channels, time_embed_dim) 174 | 175 | if cond_proj_dim is not None: 176 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) 177 | else: 178 | self.cond_proj = None 179 | 180 | if act_fn == "silu": 181 | self.act = nn.SiLU() 182 | elif act_fn == "mish": 183 | self.act = nn.Mish() 184 | elif act_fn == "gelu": 185 | self.act = nn.GELU() 186 | else: 187 | raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") 188 | 189 | if out_dim is not None: 190 | time_embed_dim_out = out_dim 191 | else: 192 | time_embed_dim_out = time_embed_dim 193 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) 194 | 195 | if post_act_fn is None: 196 | self.post_act = None 197 | elif post_act_fn == "silu": 198 | self.post_act = nn.SiLU() 199 | elif post_act_fn == "mish": 200 | self.post_act = nn.Mish() 201 | elif post_act_fn == "gelu": 202 | self.post_act = nn.GELU() 203 | else: 204 | raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") 205 | 206 | def forward(self, sample, condition=None): 207 | if condition is not None: 208 | sample = sample + self.cond_proj(condition) 209 | sample = self.linear_1(sample) 210 | 211 | if self.act is not None: 212 | sample = self.act(sample) 213 | 214 | sample = self.linear_2(sample) 215 | 216 | if self.post_act is not None: 217 | sample = self.post_act(sample) 218 | return sample 219 | 220 | 221 | class Timesteps(nn.Module): 222 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 223 | super().__init__() 224 | self.num_channels = num_channels 225 | self.flip_sin_to_cos = flip_sin_to_cos 226 | self.downscale_freq_shift = downscale_freq_shift 227 | 228 | def forward(self, timesteps): 229 | t_emb = get_timestep_embedding( 230 | timesteps, 231 | self.num_channels, 232 | flip_sin_to_cos=self.flip_sin_to_cos, 233 | downscale_freq_shift=self.downscale_freq_shift, 234 | ) 235 | return t_emb 236 | 237 | 238 | class GaussianFourierProjection(nn.Module): 239 | """Gaussian Fourier embeddings for noise levels.""" 240 | 241 | def __init__( 242 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False 243 | ): 244 | super().__init__() 245 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 246 | self.log = log 247 | self.flip_sin_to_cos = flip_sin_to_cos 248 | 249 | if set_W_to_weight: 250 | # to delete later 251 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 252 | 253 | self.weight = self.W 254 | 255 | def forward(self, x): 256 | if self.log: 257 | x = torch.log(x) 258 | 259 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi 260 | 261 | if self.flip_sin_to_cos: 262 | out = torch.cat([torch.cos(x_proj), 
torch.sin(x_proj)], dim=-1) 263 | else: 264 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 265 | return out 266 | 267 | 268 | class ImagePositionalEmbeddings(nn.Module): 269 | """ 270 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the 271 | height and width of the latent space. 272 | 273 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 274 | 275 | For VQ-diffusion: 276 | 277 | Output vector embeddings are used as input for the transformer. 278 | 279 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. 280 | 281 | Args: 282 | num_embed (`int`): 283 | Number of embeddings for the latent pixels embeddings. 284 | height (`int`): 285 | Height of the latent image i.e. the number of height embeddings. 286 | width (`int`): 287 | Width of the latent image i.e. the number of width embeddings. 288 | embed_dim (`int`): 289 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 290 | """ 291 | 292 | def __init__( 293 | self, 294 | num_embed: int, 295 | height: int, 296 | width: int, 297 | embed_dim: int, 298 | ): 299 | super().__init__() 300 | 301 | self.height = height 302 | self.width = width 303 | self.num_embed = num_embed 304 | self.embed_dim = embed_dim 305 | 306 | self.emb = nn.Embedding(self.num_embed, embed_dim) 307 | self.height_emb = nn.Embedding(self.height, embed_dim) 308 | self.width_emb = nn.Embedding(self.width, embed_dim) 309 | 310 | def forward(self, index): 311 | emb = self.emb(index) 312 | 313 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) 314 | 315 | # 1 x H x D -> 1 x H x 1 x D 316 | height_emb = height_emb.unsqueeze(2) 317 | 318 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) 319 | 320 | # 1 x W x D -> 1 x 1 x W x D 321 | width_emb = width_emb.unsqueeze(1) 322 | 323 | pos_emb = height_emb + width_emb 324 | 325 | # 1 x H x W x D -> 1 x L xD 326 | pos_emb = pos_emb.view(1, self.height * self.width, -1) 327 | 328 | emb = emb + pos_emb[:, : emb.shape[1], :] 329 | 330 | return emb 331 | 332 | 333 | class LabelEmbedding(nn.Module): 334 | """ 335 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 336 | 337 | Args: 338 | num_classes (`int`): The number of classes. 339 | hidden_size (`int`): The size of the vector embeddings. 340 | dropout_prob (`float`): The probability of dropping a label. 341 | """ 342 | 343 | def __init__(self, num_classes, hidden_size, dropout_prob): 344 | super().__init__() 345 | use_cfg_embedding = dropout_prob > 0 346 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 347 | self.num_classes = num_classes 348 | self.dropout_prob = dropout_prob 349 | 350 | def token_drop(self, labels, force_drop_ids=None): 351 | """ 352 | Drops labels to enable classifier-free guidance. 
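        Entries selected for dropping are mapped to the extra index ``num_classes``, which points at the
        additional embedding row reserved for the unconditional (classifier-free guidance) token.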
353 | """ 354 | if force_drop_ids is None: 355 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 356 | else: 357 | drop_ids = torch.tensor(force_drop_ids == 1) 358 | labels = torch.where(drop_ids, self.num_classes, labels) 359 | return labels 360 | 361 | def forward(self, labels, force_drop_ids=None): 362 | use_dropout = self.dropout_prob > 0 363 | if (self.training and use_dropout) or (force_drop_ids is not None): 364 | labels = self.token_drop(labels, force_drop_ids) 365 | embeddings = self.embedding_table(labels) 366 | return embeddings 367 | 368 | 369 | class CombinedTimestepLabelEmbeddings(nn.Module): 370 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): 371 | super().__init__() 372 | 373 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) 374 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 375 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) 376 | 377 | def forward(self, timestep, class_labels, hidden_dtype=None): 378 | timesteps_proj = self.time_proj(timestep) 379 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) 380 | 381 | class_labels = self.class_embedder(class_labels) # (N, D) 382 | 383 | conditioning = timesteps_emb + class_labels # (N, D) 384 | 385 | return conditioning -------------------------------------------------------------------------------- /magicanimate/models/motion_module.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | from dataclasses import dataclass 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from diffusers.utils import BaseOutput 15 | from diffusers.utils.import_utils import is_xformers_available 16 | from diffusers.models.attention import FeedForward 17 | from magicanimate.models.orig_attention import CrossAttention 18 | 19 | from einops import rearrange, repeat 20 | import math 21 | 22 | 23 | def zero_module(module): 24 | # Zero out the parameters of a module and return it. 
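    # In this codebase it is applied to `proj_out` of the temporal transformer (see `zero_initialize` below),
    # so a freshly added motion module initially behaves as an identity / residual-only branch.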
25 | for p in module.parameters(): 26 | p.detach().zero_() 27 | return module 28 | 29 | 30 | @dataclass 31 | class TemporalTransformer3DModelOutput(BaseOutput): 32 | sample: torch.FloatTensor 33 | 34 | 35 | if is_xformers_available(): 36 | import xformers 37 | import xformers.ops 38 | else: 39 | xformers = None 40 | 41 | 42 | def get_motion_module( 43 | in_channels, 44 | motion_module_type: str, 45 | motion_module_kwargs: dict 46 | ): 47 | if motion_module_type == "Vanilla": 48 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 49 | else: 50 | raise ValueError 51 | 52 | 53 | class VanillaTemporalModule(nn.Module): 54 | def __init__( 55 | self, 56 | in_channels, 57 | num_attention_heads = 8, 58 | num_transformer_block = 2, 59 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 60 | cross_frame_attention_mode = None, 61 | temporal_position_encoding = False, 62 | temporal_position_encoding_max_len = 24, 63 | temporal_attention_dim_div = 1, 64 | zero_initialize = True, 65 | ): 66 | super().__init__() 67 | 68 | self.temporal_transformer = TemporalTransformer3DModel( 69 | in_channels=in_channels, 70 | num_attention_heads=num_attention_heads, 71 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 72 | num_layers=num_transformer_block, 73 | attention_block_types=attention_block_types, 74 | cross_frame_attention_mode=cross_frame_attention_mode, 75 | temporal_position_encoding=temporal_position_encoding, 76 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 77 | ) 78 | 79 | if zero_initialize: 80 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 81 | 82 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 83 | hidden_states = input_tensor 84 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 85 | 86 | output = hidden_states 87 | return output 88 | 89 | 90 | class TemporalTransformer3DModel(nn.Module): 91 | def __init__( 92 | self, 93 | in_channels, 94 | num_attention_heads, 95 | attention_head_dim, 96 | 97 | num_layers, 98 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 99 | dropout = 0.0, 100 | norm_num_groups = 32, 101 | cross_attention_dim = 768, 102 | activation_fn = "geglu", 103 | attention_bias = False, 104 | upcast_attention = False, 105 | 106 | cross_frame_attention_mode = None, 107 | temporal_position_encoding = False, 108 | temporal_position_encoding_max_len = 24, 109 | ): 110 | super().__init__() 111 | 112 | inner_dim = num_attention_heads * attention_head_dim 113 | 114 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 115 | self.proj_in = nn.Linear(in_channels, inner_dim) 116 | 117 | self.transformer_blocks = nn.ModuleList( 118 | [ 119 | TemporalTransformerBlock( 120 | dim=inner_dim, 121 | num_attention_heads=num_attention_heads, 122 | attention_head_dim=attention_head_dim, 123 | attention_block_types=attention_block_types, 124 | dropout=dropout, 125 | norm_num_groups=norm_num_groups, 126 | cross_attention_dim=cross_attention_dim, 127 | activation_fn=activation_fn, 128 | attention_bias=attention_bias, 129 | upcast_attention=upcast_attention, 130 | cross_frame_attention_mode=cross_frame_attention_mode, 131 | temporal_position_encoding=temporal_position_encoding, 132 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 133 | ) 134 | for d in range(num_layers) 
135 | ] 136 | ) 137 | self.proj_out = nn.Linear(inner_dim, in_channels) 138 | 139 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 140 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 141 | video_length = hidden_states.shape[2] 142 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 143 | 144 | batch, channel, height, weight = hidden_states.shape 145 | residual = hidden_states 146 | 147 | hidden_states = self.norm(hidden_states) 148 | inner_dim = hidden_states.shape[1] 149 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 150 | hidden_states = self.proj_in(hidden_states) 151 | 152 | # Transformer Blocks 153 | for block in self.transformer_blocks: 154 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 155 | 156 | # output 157 | hidden_states = self.proj_out(hidden_states) 158 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 159 | 160 | output = hidden_states + residual 161 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 162 | 163 | return output 164 | 165 | 166 | class TemporalTransformerBlock(nn.Module): 167 | def __init__( 168 | self, 169 | dim, 170 | num_attention_heads, 171 | attention_head_dim, 172 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 173 | dropout = 0.0, 174 | norm_num_groups = 32, 175 | cross_attention_dim = 768, 176 | activation_fn = "geglu", 177 | attention_bias = False, 178 | upcast_attention = False, 179 | cross_frame_attention_mode = None, 180 | temporal_position_encoding = False, 181 | temporal_position_encoding_max_len = 24, 182 | ): 183 | super().__init__() 184 | 185 | attention_blocks = [] 186 | norms = [] 187 | 188 | for block_name in attention_block_types: 189 | attention_blocks.append( 190 | VersatileAttention( 191 | attention_mode=block_name.split("_")[0], 192 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 193 | 194 | query_dim=dim, 195 | heads=num_attention_heads, 196 | dim_head=attention_head_dim, 197 | dropout=dropout, 198 | bias=attention_bias, 199 | upcast_attention=upcast_attention, 200 | 201 | cross_frame_attention_mode=cross_frame_attention_mode, 202 | temporal_position_encoding=temporal_position_encoding, 203 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 204 | ) 205 | ) 206 | norms.append(nn.LayerNorm(dim)) 207 | 208 | self.attention_blocks = nn.ModuleList(attention_blocks) 209 | self.norms = nn.ModuleList(norms) 210 | 211 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 212 | self.ff_norm = nn.LayerNorm(dim) 213 | 214 | 215 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 216 | for attention_block, norm in zip(self.attention_blocks, self.norms): 217 | norm_hidden_states = norm(hidden_states) 218 | hidden_states = attention_block( 219 | norm_hidden_states, 220 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 221 | video_length=video_length, 222 | ) + hidden_states 223 | 224 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 225 | 226 | output = hidden_states 227 | return output 228 | 229 | 230 | class PositionalEncoding(nn.Module): 231 | def __init__( 232 | self, 233 | d_model, 234 | dropout = 0., 235 | max_len = 24 236 | ): 
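        """Standard sinusoidal positional encoding applied along the temporal (frame) axis;
        `max_len` caps how many frames can be encoded (24 by default)."""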
237 | super().__init__() 238 | self.dropout = nn.Dropout(p=dropout) 239 | position = torch.arange(max_len).unsqueeze(1) 240 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 241 | pe = torch.zeros(1, max_len, d_model) 242 | pe[0, :, 0::2] = torch.sin(position * div_term) 243 | pe[0, :, 1::2] = torch.cos(position * div_term) 244 | self.register_buffer('pe', pe) 245 | 246 | def forward(self, x): 247 | x = x + self.pe[:, :x.size(1)] 248 | return self.dropout(x) 249 | 250 | 251 | class VersatileAttention(CrossAttention): 252 | def __init__( 253 | self, 254 | attention_mode = None, 255 | cross_frame_attention_mode = None, 256 | temporal_position_encoding = False, 257 | temporal_position_encoding_max_len = 24, 258 | *args, **kwargs 259 | ): 260 | super().__init__(*args, **kwargs) 261 | assert attention_mode == "Temporal" 262 | 263 | self.attention_mode = attention_mode 264 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 265 | 266 | self.pos_encoder = PositionalEncoding( 267 | kwargs["query_dim"], 268 | dropout=0., 269 | max_len=temporal_position_encoding_max_len 270 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 271 | 272 | def extra_repr(self): 273 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 274 | 275 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 276 | batch_size, sequence_length, _ = hidden_states.shape 277 | 278 | if self.attention_mode == "Temporal": 279 | d = hidden_states.shape[1] 280 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 281 | 282 | if self.pos_encoder is not None: 283 | hidden_states = self.pos_encoder(hidden_states) 284 | 285 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 286 | else: 287 | raise NotImplementedError 288 | 289 | encoder_hidden_states = encoder_hidden_states 290 | 291 | if self.group_norm is not None: 292 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 293 | 294 | query = self.to_q(hidden_states) 295 | dim = query.shape[-1] 296 | query = self.reshape_heads_to_batch_dim(query) 297 | 298 | if self.added_kv_proj_dim is not None: 299 | raise NotImplementedError 300 | 301 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 302 | key = self.to_k(encoder_hidden_states) 303 | value = self.to_v(encoder_hidden_states) 304 | 305 | key = self.reshape_heads_to_batch_dim(key) 306 | value = self.reshape_heads_to_batch_dim(value) 307 | 308 | if attention_mask is not None: 309 | if attention_mask.shape[-1] != query.shape[1]: 310 | target_length = query.shape[1] 311 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 312 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 313 | 314 | # attention, what we cannot get enough of 315 | if self._use_memory_efficient_attention_xformers: 316 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 317 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 318 | hidden_states = hidden_states.to(query.dtype) 319 | else: 320 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 321 | hidden_states = self._attention(query, key, value, attention_mask) 322 | else: 323 | hidden_states = 
self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 324 | 325 | # linear proj 326 | hidden_states = self.to_out[0](hidden_states) 327 | 328 | # dropout 329 | hidden_states = self.to_out[1](hidden_states) 330 | 331 | if self.attention_mode == "Temporal": 332 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 333 | 334 | return hidden_states 335 | -------------------------------------------------------------------------------- /magicanimate/models/resnet.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | 9 | # Copyright 2023 The HuggingFace Team. All rights reserved. 10 | # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. 11 | # 12 | # Licensed under the Apache License, Version 2.0 (the "License"); 13 | # you may not use this file except in compliance with the License. 14 | # You may obtain a copy of the License at 15 | # 16 | # http://www.apache.org/licenses/LICENSE-2.0 17 | # 18 | # Unless required by applicable law or agreed to in writing, software 19 | # distributed under the License is distributed on an "AS IS" BASIS, 20 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | # See the License for the specific language governing permissions and 22 | # limitations under the License. 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | from einops import rearrange 28 | 29 | 30 | class InflatedConv3d(nn.Conv2d): 31 | def forward(self, x): 32 | video_length = x.shape[2] 33 | 34 | x = rearrange(x, "b c f h w -> (b f) c h w") 35 | x = super().forward(x) 36 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 37 | 38 | return x 39 | 40 | 41 | class Upsample3D(nn.Module): 42 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 43 | super().__init__() 44 | self.channels = channels 45 | self.out_channels = out_channels or channels 46 | self.use_conv = use_conv 47 | self.use_conv_transpose = use_conv_transpose 48 | self.name = name 49 | 50 | conv = None 51 | if use_conv_transpose: 52 | raise NotImplementedError 53 | elif use_conv: 54 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 55 | 56 | def forward(self, hidden_states, output_size=None): 57 | assert hidden_states.shape[1] == self.channels 58 | 59 | if self.use_conv_transpose: 60 | raise NotImplementedError 61 | 62 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 63 | dtype = hidden_states.dtype 64 | if dtype == torch.bfloat16: 65 | hidden_states = hidden_states.to(torch.float32) 66 | 67 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 68 | if hidden_states.shape[0] >= 64: 69 | hidden_states = hidden_states.contiguous() 70 | 71 | # if `output_size` is passed we force the interpolation output 72 | # size and do not make use of `scale_factor=2` 73 | if output_size is None: 74 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 75 | else: 76 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 77 | 78 | # If the input is bfloat16, we cast back to bfloat16 79 | if dtype == torch.bfloat16: 80 | hidden_states = hidden_states.to(dtype) 81 | 82 | hidden_states = self.conv(hidden_states) 83 | 84 | return hidden_states 85 | 86 | 87 | class Downsample3D(nn.Module): 88 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 89 | super().__init__() 90 | self.channels = channels 91 | self.out_channels = out_channels or channels 92 | self.use_conv = use_conv 93 | self.padding = padding 94 | stride = 2 95 | self.name = name 96 | 97 | if use_conv: 98 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 99 | else: 100 | raise NotImplementedError 101 | 102 | def forward(self, hidden_states): 103 | assert hidden_states.shape[1] == self.channels 104 | if self.use_conv and self.padding == 0: 105 | raise NotImplementedError 106 | 107 | assert hidden_states.shape[1] == self.channels 108 | hidden_states = self.conv(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | class ResnetBlock3D(nn.Module): 114 | def __init__( 115 | self, 116 | *, 117 | in_channels, 118 | out_channels=None, 119 | conv_shortcut=False, 120 | dropout=0.0, 121 | temb_channels=512, 122 | groups=32, 123 | groups_out=None, 124 | pre_norm=True, 125 | eps=1e-6, 126 | non_linearity="swish", 127 | time_embedding_norm="default", 128 | output_scale_factor=1.0, 129 | use_in_shortcut=None, 130 | ): 131 | super().__init__() 132 | self.pre_norm = pre_norm 133 | self.pre_norm = True 134 | self.in_channels = in_channels 135 | out_channels = in_channels if out_channels is None else out_channels 136 | self.out_channels = out_channels 137 | self.use_conv_shortcut = conv_shortcut 138 | self.time_embedding_norm = time_embedding_norm 139 | self.output_scale_factor = output_scale_factor 140 | 141 | if groups_out is None: 142 | groups_out = groups 143 | 144 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 145 | 146 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 147 | 148 | if temb_channels is not None: 149 | if self.time_embedding_norm == "default": 150 | time_emb_proj_out_channels = out_channels 151 | elif self.time_embedding_norm == "scale_shift": 152 | time_emb_proj_out_channels = out_channels * 2 153 | else: 154 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 155 | 156 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 157 | else: 158 | self.time_emb_proj = None 159 | 160 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 161 | self.dropout = torch.nn.Dropout(dropout) 162 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 163 | 164 | if non_linearity == "swish": 165 | self.nonlinearity = lambda x: F.silu(x) 166 | elif non_linearity == "mish": 167 | self.nonlinearity = Mish() 168 | elif non_linearity == "silu": 169 | 
self.nonlinearity = nn.SiLU() 170 | 171 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 172 | 173 | self.conv_shortcut = None 174 | if self.use_in_shortcut: 175 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 176 | 177 | def forward(self, input_tensor, temb): 178 | hidden_states = input_tensor 179 | 180 | hidden_states = self.norm1(hidden_states) 181 | hidden_states = self.nonlinearity(hidden_states) 182 | 183 | hidden_states = self.conv1(hidden_states) 184 | 185 | if temb is not None: 186 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 187 | 188 | if temb is not None and self.time_embedding_norm == "default": 189 | hidden_states = hidden_states + temb 190 | 191 | hidden_states = self.norm2(hidden_states) 192 | 193 | if temb is not None and self.time_embedding_norm == "scale_shift": 194 | scale, shift = torch.chunk(temb, 2, dim=1) 195 | hidden_states = hidden_states * (1 + scale) + shift 196 | 197 | hidden_states = self.nonlinearity(hidden_states) 198 | 199 | hidden_states = self.dropout(hidden_states) 200 | hidden_states = self.conv2(hidden_states) 201 | 202 | if self.conv_shortcut is not None: 203 | input_tensor = self.conv_shortcut(input_tensor) 204 | 205 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 206 | 207 | return output_tensor 208 | 209 | 210 | class Mish(torch.nn.Module): 211 | def forward(self, hidden_states): 212 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /magicanimate/pipelines/animation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 
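# Illustrative invocation (the config path is a placeholder; see the argparse flags at the
# bottom of this file for the full set of options):
#   python -m magicanimate.pipelines.animation --config path/to/animation.yaml [--dist]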
11 | import argparse 12 | import datetime 13 | import inspect 14 | import os 15 | import random 16 | import numpy as np 17 | 18 | from PIL import Image 19 | from omegaconf import OmegaConf 20 | from collections import OrderedDict 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | from diffusers import AutoencoderKL, DDIMScheduler 26 | 27 | from tqdm import tqdm 28 | from transformers import CLIPTextModel, CLIPTokenizer 29 | 30 | from magicanimate.models.unet_controlnet import UNet3DConditionModel 31 | from magicanimate.models.controlnet import ControlNetModel 32 | from magicanimate.models.appearance_encoder import AppearanceEncoderModel 33 | from magicanimate.models.mutual_self_attention import ReferenceAttentionControl 34 | from magicanimate.pipelines.pipeline_animation import AnimationPipeline 35 | from magicanimate.utils.util import save_videos_grid 36 | from magicanimate.utils.dist_tools import distributed_init 37 | from accelerate.utils import set_seed 38 | 39 | from magicanimate.utils.videoreader import VideoReader 40 | 41 | from einops import rearrange 42 | 43 | from pathlib import Path 44 | 45 | 46 | def main(args): 47 | 48 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 49 | func_args = dict(func_args) 50 | 51 | config = OmegaConf.load(args.config) 52 | 53 | # Initialize distributed training 54 | device = torch.device(f"cuda:{args.rank}") 55 | dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} 56 | 57 | if config.savename is None: 58 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 59 | savedir = f"samples/{Path(args.config).stem}-{time_str}" 60 | else: 61 | savedir = f"samples/{config.savename}" 62 | 63 | if args.dist: 64 | dist.broadcast_object_list([savedir], 0) 65 | dist.barrier() 66 | 67 | if args.rank == 0: 68 | os.makedirs(savedir, exist_ok=True) 69 | 70 | inference_config = OmegaConf.load(config.inference_config) 71 | 72 | motion_module = config.motion_module 73 | 74 | ### >>> create animation pipeline >>> ### 75 | tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") 76 | text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") 77 | if config.pretrained_unet_path: 78 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 79 | else: 80 | unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) 81 | appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) 82 | reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) 83 | reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) 84 | if config.pretrained_vae_path is not None: 85 | vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) 86 | else: 87 | vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") 88 | 89 | ### Load controlnet 90 | controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) 91 | 92 | unet.enable_xformers_memory_efficient_attention() 93 | 
appearance_encoder.enable_xformers_memory_efficient_attention() 94 | controlnet.enable_xformers_memory_efficient_attention() 95 | 96 | vae.to(torch.float16) 97 | unet.to(torch.float16) 98 | text_encoder.to(torch.float16) 99 | appearance_encoder.to(torch.float16) 100 | controlnet.to(torch.float16) 101 | 102 | pipeline = AnimationPipeline( 103 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, 104 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), 105 | # NOTE: UniPCMultistepScheduler 106 | ) 107 | 108 | # 1. unet ckpt 109 | # 1.1 motion module 110 | motion_module_state_dict = torch.load(motion_module, map_location="cpu") 111 | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) 112 | motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict 113 | try: 114 | # extra steps for self-trained models 115 | state_dict = OrderedDict() 116 | for key in motion_module_state_dict.keys(): 117 | if key.startswith("module."): 118 | _key = key.split("module.")[-1] 119 | state_dict[_key] = motion_module_state_dict[key] 120 | else: 121 | state_dict[key] = motion_module_state_dict[key] 122 | motion_module_state_dict = state_dict 123 | del state_dict 124 | missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) 125 | assert len(unexpected) == 0 126 | except: 127 | _tmp_ = OrderedDict() 128 | for key in motion_module_state_dict.keys(): 129 | if "motion_modules" in key: 130 | if key.startswith("unet."): 131 | _key = key.split('unet.')[-1] 132 | _tmp_[_key] = motion_module_state_dict[key] 133 | else: 134 | _tmp_[key] = motion_module_state_dict[key] 135 | missing, unexpected = unet.load_state_dict(_tmp_, strict=False) 136 | assert len(unexpected) == 0 137 | del _tmp_ 138 | del motion_module_state_dict 139 | 140 | pipeline.to(device) 141 | ### <<< create validation pipeline <<< ### 142 | 143 | random_seeds = config.get("seed", [-1]) 144 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 145 | random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds 146 | 147 | # input test videos (either source video/ conditions) 148 | 149 | test_videos = config.video_path 150 | source_images = config.source_image 151 | num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps) 152 | 153 | # read size, step from yaml file 154 | sizes = [config.size] * len(test_videos) 155 | steps = [config.S] * len(test_videos) 156 | 157 | config.random_seed = [] 158 | prompt = n_prompt = "" 159 | for idx, (source_image, test_video, random_seed, size, step) in tqdm( 160 | enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), 161 | total=len(test_videos), 162 | disable=(args.rank!=0) 163 | ): 164 | samples_per_video = [] 165 | samples_per_clip = [] 166 | # manually set random seed for reproduction 167 | if random_seed != -1: 168 | torch.manual_seed(random_seed) 169 | set_seed(random_seed) 170 | else: 171 | torch.seed() 172 | config.random_seed.append(torch.initial_seed()) 173 | 174 | if test_video.endswith('.mp4'): 175 | control = VideoReader(test_video).read() 176 | if control[0].shape[0] != size: 177 | control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] 178 | if config.max_length is not None: 179 | control = control[config.offset: 
(config.offset+config.max_length)] 180 | control = np.array(control) 181 | 182 | if source_image.endswith(".mp4"): 183 | source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size))) 184 | else: 185 | source_image = np.array(Image.open(source_image).resize((size, size))) 186 | H, W, C = source_image.shape 187 | 188 | print(f"current seed: {torch.initial_seed()}") 189 | init_latents = None 190 | 191 | # print(f"sampling {prompt} ...") 192 | original_length = control.shape[0] 193 | if control.shape[0] % config.L > 0: 194 | control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge') 195 | generator = torch.Generator(device=torch.device("cuda:0")) 196 | generator.manual_seed(torch.initial_seed()) 197 | sample = pipeline( 198 | prompt, 199 | negative_prompt = n_prompt, 200 | num_inference_steps = config.steps, 201 | guidance_scale = config.guidance_scale, 202 | width = W, 203 | height = H, 204 | video_length = len(control), 205 | controlnet_condition = control, 206 | init_latents = init_latents, 207 | generator = generator, 208 | num_actual_inference_steps = num_actual_inference_steps, 209 | appearance_encoder = appearance_encoder, 210 | reference_control_writer = reference_control_writer, 211 | reference_control_reader = reference_control_reader, 212 | source_image = source_image, 213 | **dist_kwargs, 214 | ).videos 215 | 216 | if args.rank == 0: 217 | source_images = np.array([source_image] * original_length) 218 | source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 219 | samples_per_video.append(source_images) 220 | 221 | control = control / 255.0 222 | control = rearrange(control, "t h w c -> 1 c t h w") 223 | control = torch.from_numpy(control) 224 | samples_per_video.append(control[:, :, :original_length]) 225 | 226 | samples_per_video.append(sample[:, :, :original_length]) 227 | 228 | samples_per_video = torch.cat(samples_per_video) 229 | 230 | video_name = os.path.basename(test_video)[:-4] 231 | source_name = os.path.basename(config.source_image[idx]).split(".")[0] 232 | save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4") 233 | save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4") 234 | 235 | if config.save_individual_videos: 236 | save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4") 237 | save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4") 238 | 239 | if args.dist: 240 | dist.barrier() 241 | 242 | if args.rank == 0: 243 | OmegaConf.save(config, f"{savedir}/config.yaml") 244 | 245 | 246 | def distributed_main(device_id, args): 247 | args.rank = device_id 248 | args.device_id = device_id 249 | if torch.cuda.is_available(): 250 | torch.cuda.set_device(args.device_id) 251 | torch.cuda.init() 252 | distributed_init(args) 253 | main(args) 254 | 255 | 256 | def run(args): 257 | 258 | if args.dist: 259 | args.world_size = max(1, torch.cuda.device_count()) 260 | assert args.world_size <= torch.cuda.device_count() 261 | 262 | if args.world_size > 0 and torch.cuda.device_count() > 1: 263 | port = random.randint(10000, 20000) 264 | args.init_method = f"tcp://localhost:{port}" 265 | torch.multiprocessing.spawn( 266 | fn=distributed_main, 267 | args=(args,), 268 | nprocs=args.world_size, 269 | ) 270 | else: 271 | main(args) 272 | 273 | 274 | if __name__ == "__main__": 275 | parser = argparse.ArgumentParser() 
276 | parser.add_argument("--config", type=str, required=True) 277 | parser.add_argument("--dist", action="store_true", required=False) 278 | parser.add_argument("--rank", type=int, default=0, required=False) 279 | parser.add_argument("--world_size", type=int, default=1, required=False) 280 | 281 | args = parser.parse_args() 282 | run(args) 283 | -------------------------------------------------------------------------------- /magicanimate/pipelines/context.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main 8 | import numpy as np 9 | from typing import Callable, Optional, List 10 | 11 | 12 | def ordered_halving(val): 13 | bin_str = f"{val:064b}" 14 | bin_flip = bin_str[::-1] 15 | as_int = int(bin_flip, 2) 16 | 17 | return as_int / (1 << 64) 18 | 19 | 20 | def uniform( 21 | step: int = ..., 22 | num_steps: Optional[int] = None, 23 | num_frames: int = ..., 24 | context_size: Optional[int] = None, 25 | context_stride: int = 3, 26 | context_overlap: int = 4, 27 | closed_loop: bool = True, 28 | ): 29 | if num_frames <= context_size: 30 | yield list(range(num_frames)) 31 | return 32 | 33 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 34 | 35 | for context_step in 1 << np.arange(context_stride): 36 | pad = int(round(num_frames * ordered_halving(step))) 37 | for j in range( 38 | int(ordered_halving(step) * context_step) + pad, 39 | num_frames + pad + (0 if closed_loop else -context_overlap), 40 | (context_size * context_step - context_overlap), 41 | ): 42 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /magicanimate/utils/dist_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 ByteDance and/or its affiliates. 2 | # 3 | # Copyright (2023) MagicAnimate Authors 4 | # 5 | # ByteDance, its affiliates and licensors retain all intellectual 6 | # property and proprietary rights in and to this material, related 7 | # documentation and any modifications thereto. Any use, reproduction, 8 | # disclosure or distribution of this material and related documentation 9 | # without an express license agreement from ByteDance or 10 | # its affiliates is strictly prohibited. 
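# Helpers for single-node multi-GPU runs. `distributed_init` assumes an argparse-style namespace
# whose `rank`, `world_size` and `init_method` (e.g. a "tcp://localhost:<port>" URL) have already
# been filled in, as done by `distributed_main`/`run` in magicanimate/pipelines/animation.py.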
11 | import os 12 | import socket 13 | import warnings 14 | import torch 15 | from torch import distributed as dist 16 | 17 | 18 | def distributed_init(args): 19 | 20 | if dist.is_initialized(): 21 | warnings.warn("Distributed is already initialized, cannot initialize twice!") 22 | args.rank = dist.get_rank() 23 | else: 24 | print( 25 | f"Distributed Init (Rank {args.rank}): " 26 | f"{args.init_method}" 27 | ) 28 | dist.init_process_group( 29 | backend='nccl', 30 | init_method=args.init_method, 31 | world_size=args.world_size, 32 | rank=args.rank, 33 | ) 34 | print( 35 | f"Initialized Host {socket.gethostname()} as Rank " 36 | f"{args.rank}" 37 | ) 38 | 39 | if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ: 40 | # Set for onboxdataloader support 41 | split = args.init_method.split("//") 42 | assert len(split) == 2, ( 43 | "host url for distributed should be split by '//' " 44 | + "into exactly two elements" 45 | ) 46 | 47 | split = split[1].split(":") 48 | assert ( 49 | len(split) == 2 50 | ), "host url should be of the form :" 51 | os.environ["MASTER_ADDR"] = split[0] 52 | os.environ["MASTER_PORT"] = split[1] 53 | 54 | # perform a dummy all-reduce to initialize the NCCL communicator 55 | dist.all_reduce(torch.zeros(1).cuda()) 56 | 57 | suppress_output(is_master()) 58 | args.rank = dist.get_rank() 59 | return args.rank 60 | 61 | 62 | def get_rank(): 63 | if not dist.is_available(): 64 | return 0 65 | if not dist.is_nccl_available(): 66 | return 0 67 | if not dist.is_initialized(): 68 | return 0 69 | return dist.get_rank() 70 | 71 | 72 | def is_master(): 73 | return get_rank() == 0 74 | 75 | 76 | def synchronize(): 77 | if dist.is_initialized(): 78 | dist.barrier() 79 | 80 | 81 | def suppress_output(is_master): 82 | """Suppress printing on the current device. Force printing with `force=True`.""" 83 | import builtins as __builtin__ 84 | 85 | builtin_print = __builtin__.print 86 | 87 | def print(*args, **kwargs): 88 | force = kwargs.pop("force", False) 89 | if is_master or force: 90 | builtin_print(*args, **kwargs) 91 | 92 | __builtin__.print = print 93 | 94 | import warnings 95 | 96 | builtin_warn = warnings.warn 97 | 98 | def warn(*args, **kwargs): 99 | force = kwargs.pop("force", False) 100 | if is_master or force: 101 | builtin_warn(*args, **kwargs) 102 | 103 | # Log warnings only once 104 | warnings.warn = warn 105 | warnings.simplefilter("once", UserWarning) -------------------------------------------------------------------------------- /magicanimate/utils/util.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 
5 | # ************************************************************************* 6 | 7 | # Adapted from https://github.com/guoyww/AnimateDiff 8 | import os 9 | import imageio 10 | import numpy as np 11 | 12 | import torch 13 | import torchvision 14 | 15 | from PIL import Image 16 | from typing import Union 17 | from tqdm import tqdm 18 | from einops import rearrange 19 | 20 | 21 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25): 22 | videos = rearrange(videos, "b c t h w -> t b c h w") 23 | outputs = [] 24 | for x in videos: 25 | x = torchvision.utils.make_grid(x, nrow=n_rows) 26 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 27 | if rescale: 28 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 29 | x = (x * 255).numpy().astype(np.uint8) 30 | outputs.append(x) 31 | 32 | os.makedirs(os.path.dirname(path), exist_ok=True) 33 | imageio.mimsave(path, outputs, fps=fps) 34 | 35 | def save_images_grid(images: torch.Tensor, path: str): 36 | assert images.shape[2] == 1 # no time dimension 37 | images = images.squeeze(2) 38 | grid = torchvision.utils.make_grid(images) 39 | grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8) 40 | os.makedirs(os.path.dirname(path), exist_ok=True) 41 | Image.fromarray(grid).save(path) 42 | 43 | # DDIM Inversion 44 | @torch.no_grad() 45 | def init_prompt(prompt, pipeline): 46 | uncond_input = pipeline.tokenizer( 47 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 48 | return_tensors="pt" 49 | ) 50 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 51 | text_input = pipeline.tokenizer( 52 | [prompt], 53 | padding="max_length", 54 | max_length=pipeline.tokenizer.model_max_length, 55 | truncation=True, 56 | return_tensors="pt", 57 | ) 58 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 59 | context = torch.cat([uncond_embeddings, text_embeddings]) 60 | 61 | return context 62 | 63 | 64 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 65 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 66 | timestep, next_timestep = min( 67 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 68 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 69 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 70 | beta_prod_t = 1 - alpha_prod_t 71 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 72 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 73 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 74 | return next_sample 75 | 76 | 77 | def get_noise_pred_single(latents, t, context, unet): 78 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 79 | return noise_pred 80 | 81 | 82 | @torch.no_grad() 83 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 84 | context = init_prompt(prompt, pipeline) 85 | uncond_embeddings, cond_embeddings = context.chunk(2) 86 | all_latent = [latent] 87 | latent = latent.clone().detach() 88 | for i in tqdm(range(num_inv_steps)): 89 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 90 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 91 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 92 | all_latent.append(latent) 93 | return all_latent 94 | 
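# Illustrative use of the DDIM-inversion helpers (assumes `pipeline`, `ddim_scheduler` and a
# VAE-encoded `video_latent` already exist; the names are placeholders):
#   all_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps=50)
#   init_latents = all_latents[-1]  # most-noised latent, usable as a starting point for sampling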
95 | 96 | @torch.no_grad() 97 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 98 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 99 | return ddim_latents 100 | 101 | 102 | def video2images(path, step=4, length=16, start=0): 103 | reader = imageio.get_reader(path) 104 | frames = [] 105 | for frame in reader: 106 | frames.append(np.array(frame)) 107 | frames = frames[start::step][:length] 108 | return frames 109 | 110 | 111 | def images2video(video, path, fps=8): 112 | imageio.mimsave(path, video, fps=fps) 113 | return 114 | 115 | 116 | tensor_interpolation = None 117 | 118 | def get_tensor_interpolation_method(): 119 | return tensor_interpolation 120 | 121 | def set_tensor_interpolation_method(is_slerp): 122 | global tensor_interpolation 123 | tensor_interpolation = slerp if is_slerp else linear 124 | 125 | def linear(v1, v2, t): 126 | return (1.0 - t) * v1 + t * v2 127 | 128 | def slerp( 129 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 130 | ) -> torch.Tensor: 131 | u0 = v0 / v0.norm() 132 | u1 = v1 / v1.norm() 133 | dot = (u0 * u1).sum() 134 | if dot.abs() > DOT_THRESHOLD: 135 | #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 136 | return (1.0 - t) * v0 + t * v1 137 | omega = dot.acos() 138 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() -------------------------------------------------------------------------------- /magicanimate/utils/videoreader.py: -------------------------------------------------------------------------------- 1 | # ************************************************************************* 2 | # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- 3 | # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- 4 | # ytedance Inc.. 5 | # ************************************************************************* 6 | 7 | # Copyright 2022 ByteDance and/or its affiliates. 8 | # 9 | # Copyright (2022) PV3D Authors 10 | # 11 | # ByteDance, its affiliates and licensors retain all intellectual 12 | # property and proprietary rights in and to this material, related 13 | # documentation and any modifications thereto. Any use, reproduction, 14 | # disclosure or distribution of this material and related documentation 15 | # without an express license agreement from ByteDance or 16 | # its affiliates is strictly prohibited. 17 | import av, gc 18 | import torch 19 | import warnings 20 | import numpy as np 21 | 22 | 23 | _CALLED_TIMES = 0 24 | _GC_COLLECTION_INTERVAL = 20 25 | 26 | 27 | # remove warnings 28 | av.logging.set_level(av.logging.ERROR) 29 | 30 | 31 | class VideoReader(): 32 | """ 33 | Simple wrapper around PyAV that exposes a few useful functions for 34 | dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries. 
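    Typical use in this repo: ``VideoReader(path).read()`` returns every decoded frame as a
    uint8 ndarray of shape (num_frames, H, W, 3).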
35 | Acknowledgement: Codes are borrowed from Bruno Korbar 36 | """ 37 | def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False): 38 | """ 39 | Arguments: 40 | video_path (str): path or byte of the video to be loaded 41 | """ 42 | self.container = av.open(video) 43 | self.num_frames = num_frames 44 | self.bi_frame = bi_frame 45 | 46 | self.resampler = None 47 | if audio_resample_rate is not None: 48 | self.resampler = av.AudioResampler(rate=audio_resample_rate) 49 | 50 | if self.container.streams.video: 51 | # enable multi-threaded video decoding 52 | if decode_lossy: 53 | warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) 54 | self.container.streams.video[0].thread_type = 'AUTO' 55 | self.video_stream = self.container.streams.video[0] 56 | else: 57 | self.video_stream = None 58 | 59 | self.fps = self._get_video_frame_rate() 60 | 61 | def seek(self, pts, backward=True, any_frame=False): 62 | stream = self.video_stream 63 | self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream) 64 | 65 | def _occasional_gc(self): 66 | # there are a lot of reference cycles in PyAV, so need to manually call 67 | # the garbage collector from time to time 68 | global _CALLED_TIMES, _GC_COLLECTION_INTERVAL 69 | _CALLED_TIMES += 1 70 | if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: 71 | gc.collect() 72 | 73 | def _read_video(self, offset): 74 | self._occasional_gc() 75 | 76 | pts = self.container.duration * offset 77 | time_ = pts / float(av.time_base) 78 | self.container.seek(int(pts)) 79 | 80 | video_frames = [] 81 | count = 0 82 | for _, frame in enumerate(self._iter_frames()): 83 | if frame.pts * frame.time_base >= time_: 84 | video_frames.append(frame) 85 | if count >= self.num_frames - 1: 86 | break 87 | count += 1 88 | return video_frames 89 | 90 | def _iter_frames(self): 91 | for packet in self.container.demux(self.video_stream): 92 | for frame in packet.decode(): 93 | yield frame 94 | 95 | def _compute_video_stats(self): 96 | if self.video_stream is None or self.container is None: 97 | return 0 98 | num_of_frames = self.container.streams.video[0].frames 99 | if num_of_frames == 0: 100 | num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base) 101 | self.seek(0, backward=False) 102 | count = 0 103 | time_base = 512 104 | for p in self.container.decode(video=0): 105 | count = count + 1 106 | if count == 1: 107 | start_pts = p.pts 108 | elif count == 2: 109 | time_base = p.pts - start_pts 110 | break 111 | return start_pts, time_base, num_of_frames 112 | 113 | def _get_video_frame_rate(self): 114 | return float(self.container.streams.video[0].guessed_rate) 115 | 116 | def sample(self, debug=False): 117 | 118 | if self.container is None: 119 | raise RuntimeError('video stream not found') 120 | sample = dict() 121 | _, _, total_num_frames = self._compute_video_stats() 122 | offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item() 123 | video_frames = self._read_video(offset/total_num_frames) 124 | video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) 125 | sample["frames"] = video_frames 126 | sample["frame_idx"] = [offset] 127 | 128 | if self.bi_frame: 129 | frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)] 130 | frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)] 131 | frames.sort() 132 | video_frames = 
np.array([video_frames[min(frames)], video_frames[max(frames)]]) 133 | Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)] 134 | sample["frames"] = video_frames 135 | sample["real_t"] = torch.tensor(Ts, dtype=torch.float32) 136 | sample["frame_idx"] = [offset+min(frames), offset+max(frames)] 137 | return sample 138 | 139 | return sample 140 | 141 | def read_frames(self, frame_indices): 142 | self.num_frames = frame_indices[1] - frame_indices[0] 143 | video_frames = self._read_video(frame_indices[0]/self.get_num_frames()) 144 | video_frames = np.array([ 145 | np.uint8(video_frames[0].to_rgb().to_ndarray()), 146 | np.uint8(video_frames[-1].to_rgb().to_ndarray()) 147 | ]) 148 | return video_frames 149 | 150 | def read(self): 151 | video_frames = self._read_video(0) 152 | video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames]) 153 | return video_frames 154 | 155 | def get_num_frames(self): 156 | _, _, total_num_frames = self._compute_video_stats() 157 | return total_num_frames -------------------------------------------------------------------------------- /models/motionmodule.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from diffusers.utils import BaseOutput 8 | from diffusers.models.attention import FeedForward 9 | from diffusers.models.attention_processor import Attention 10 | from xformers.ops import memory_efficient_attention 11 | 12 | from einops import rearrange, repeat 13 | import math 14 | 15 | 16 | def zero_module(module): 17 | # Zero out the parameters of a module and return it. 18 | for p in module.parameters(): 19 | p.detach().zero_() 20 | return module 21 | 22 | 23 | @dataclass 24 | class TemporalTransformer3DModelOutput(BaseOutput): 25 | sample: torch.FloatTensor 26 | 27 | def get_motion_module( 28 | in_channels, 29 | motion_module_type: str, 30 | motion_module_kwargs: dict 31 | ): 32 | if motion_module_type == "Vanilla": 33 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 34 | else: 35 | raise ValueError 36 | 37 | 38 | class VanillaTemporalModule(nn.Module): 39 | def __init__( 40 | self, 41 | in_channels, 42 | num_attention_heads = 8, 43 | num_transformer_block = 2, 44 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 45 | cross_frame_attention_mode = None, 46 | temporal_position_encoding = False, 47 | temporal_position_encoding_max_len = 24, 48 | temporal_attention_dim_div = 1, 49 | zero_initialize = True, 50 | ): 51 | super().__init__() 52 | 53 | self.temporal_transformer = TemporalTransformer3DModel( 54 | in_channels=in_channels, 55 | num_attention_heads=num_attention_heads, 56 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 57 | num_layers=num_transformer_block, 58 | attention_block_types=attention_block_types, 59 | cross_frame_attention_mode=cross_frame_attention_mode, 60 | temporal_position_encoding=temporal_position_encoding, 61 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 62 | ) 63 | 64 | if zero_initialize: 65 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 66 | 67 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 68 | hidden_states = input_tensor 69 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 
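        # The temporal transformer keeps the 5-D video layout (b, c, f, h, w) and only mixes
        # information across the frame axis f; spatial positions are not mixed.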
70 | 71 | output = hidden_states 72 | return output 73 | 74 | 75 | class TemporalTransformer3DModel(nn.Module): 76 | def __init__( 77 | self, 78 | in_channels, 79 | num_attention_heads, 80 | attention_head_dim, 81 | 82 | num_layers, 83 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 84 | dropout = 0.0, 85 | norm_num_groups = 32, 86 | cross_attention_dim = 768, 87 | activation_fn = "geglu", 88 | attention_bias = False, 89 | upcast_attention = False, 90 | 91 | cross_frame_attention_mode = None, 92 | temporal_position_encoding = False, 93 | temporal_position_encoding_max_len = 24, 94 | ): 95 | super().__init__() 96 | 97 | inner_dim = num_attention_heads * attention_head_dim 98 | 99 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 100 | self.proj_in = nn.Linear(in_channels, inner_dim) 101 | 102 | self.transformer_blocks = nn.ModuleList( 103 | [ 104 | TemporalTransformerBlock( 105 | dim=inner_dim, 106 | num_attention_heads=num_attention_heads, 107 | attention_head_dim=attention_head_dim, 108 | attention_block_types=attention_block_types, 109 | dropout=dropout, 110 | norm_num_groups=norm_num_groups, 111 | cross_attention_dim=cross_attention_dim, 112 | activation_fn=activation_fn, 113 | attention_bias=attention_bias, 114 | upcast_attention=upcast_attention, 115 | cross_frame_attention_mode=cross_frame_attention_mode, 116 | temporal_position_encoding=temporal_position_encoding, 117 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 118 | ) 119 | for d in range(num_layers) 120 | ] 121 | ) 122 | self.proj_out = nn.Linear(inner_dim, in_channels) 123 | 124 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 125 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
126 | video_length = hidden_states.shape[2] 127 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 128 | 129 | batch, channel, height, weight = hidden_states.shape 130 | residual = hidden_states 131 | 132 | hidden_states = self.norm(hidden_states) 133 | inner_dim = hidden_states.shape[1] 134 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 135 | hidden_states = self.proj_in(hidden_states) 136 | 137 | # Transformer Blocks 138 | for block in self.transformer_blocks: 139 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 140 | 141 | # output 142 | hidden_states = self.proj_out(hidden_states) 143 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 144 | 145 | output = hidden_states + residual 146 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 147 | 148 | return output 149 | 150 | 151 | class TemporalTransformerBlock(nn.Module): 152 | def __init__( 153 | self, 154 | dim, 155 | num_attention_heads, 156 | attention_head_dim, 157 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 158 | dropout = 0.0, 159 | norm_num_groups = 32, 160 | cross_attention_dim = 768, 161 | activation_fn = "geglu", 162 | attention_bias = False, 163 | upcast_attention = False, 164 | cross_frame_attention_mode = None, 165 | temporal_position_encoding = False, 166 | temporal_position_encoding_max_len = 24, 167 | ): 168 | super().__init__() 169 | 170 | attention_blocks = [] 171 | norms = [] 172 | 173 | for block_name in attention_block_types: 174 | attention_blocks.append( 175 | VersatileAttention( 176 | attention_mode=block_name.split("_")[0], 177 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 178 | 179 | query_dim=dim, 180 | heads=num_attention_heads, 181 | dim_head=attention_head_dim, 182 | dropout=dropout, 183 | bias=attention_bias, 184 | upcast_attention=upcast_attention, 185 | 186 | cross_frame_attention_mode=cross_frame_attention_mode, 187 | temporal_position_encoding=temporal_position_encoding, 188 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 189 | ) 190 | ) 191 | norms.append(nn.LayerNorm(dim)) 192 | 193 | self.attention_blocks = nn.ModuleList(attention_blocks) 194 | self.norms = nn.ModuleList(norms) 195 | 196 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 197 | self.ff_norm = nn.LayerNorm(dim) 198 | 199 | 200 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 201 | for attention_block, norm in zip(self.attention_blocks, self.norms): 202 | norm_hidden_states = norm(hidden_states) 203 | hidden_states = attention_block( 204 | norm_hidden_states, 205 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 206 | video_length=video_length, 207 | ) + hidden_states 208 | 209 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 210 | 211 | output = hidden_states 212 | return output 213 | 214 | 215 | class PositionalEncoding(nn.Module): 216 | def __init__( 217 | self, 218 | d_model, 219 | dropout = 0., 220 | max_len = 24 221 | ): 222 | super().__init__() 223 | self.dropout = nn.Dropout(p=dropout) 224 | position = torch.arange(max_len).unsqueeze(1) 225 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 226 | pe = torch.zeros(1, max_len, d_model) 227 | pe[0, :, 0::2] = 
torch.sin(position * div_term) 228 | pe[0, :, 1::2] = torch.cos(position * div_term) 229 | self.register_buffer('pe', pe) 230 | 231 | def forward(self, x): 232 | x = x + self.pe[:, :x.size(1)] 233 | return self.dropout(x) 234 | 235 | 236 | class VersatileAttention(Attention): 237 | def __init__( 238 | self, 239 | attention_mode = None, 240 | cross_frame_attention_mode = None, 241 | temporal_position_encoding = False, 242 | temporal_position_encoding_max_len = 24, 243 | *args, **kwargs 244 | ): 245 | super().__init__(*args, **kwargs) 246 | assert attention_mode == "Temporal" 247 | 248 | self.attention_mode = attention_mode 249 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 250 | 251 | self.pos_encoder = PositionalEncoding( 252 | kwargs["query_dim"], 253 | dropout=0., 254 | max_len=temporal_position_encoding_max_len 255 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 256 | 257 | def extra_repr(self): 258 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 259 | 260 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 261 | batch_size, sequence_length, _ = hidden_states.shape 262 | 263 | if self.attention_mode == "Temporal": 264 | d = hidden_states.shape[1] 265 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 266 | 267 | if self.pos_encoder is not None: 268 | hidden_states = self.pos_encoder(hidden_states) 269 | 270 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 271 | else: 272 | raise NotImplementedError 273 | 274 | encoder_hidden_states = encoder_hidden_states 275 | 276 | if self.group_norm is not None: 277 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 278 | 279 | query = self.to_q(hidden_states) 280 | dim = query.shape[-1] 281 | query = self.head_to_batch_dim(query) 282 | 283 | if self.added_kv_proj_dim is not None: 284 | raise NotImplementedError 285 | 286 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 287 | key = self.to_k(encoder_hidden_states) 288 | value = self.to_v(encoder_hidden_states) 289 | 290 | key = self.head_to_batch_dim(key) 291 | value = self.head_to_batch_dim(value) 292 | 293 | if attention_mask is not None: 294 | if attention_mask.shape[-1] != query.shape[1]: 295 | target_length = query.shape[1] 296 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 297 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 298 | 299 | # attention, what we cannot get enough of 300 | hidden_states = memory_efficient_attention(query, key, value) 301 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 302 | hidden_states = hidden_states.to(query.dtype) 303 | 304 | # combine heads back into one dimension 305 | hidden_states = self.batch_to_head_dim(hidden_states) 306 | 307 | # linear proj 308 | hidden_states = self.to_out[0](hidden_states) 309 | 310 | # dropout 311 | hidden_states = self.to_out[1](hidden_states) 312 | 313 | if self.attention_mode == "Temporal": 314 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 315 | 316 | return hidden_states -------------------------------------------------------------------------------- /models/videonet.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Optional, Dict, Any 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from diffusers.models import UNet2DConditionModel, Transformer2DModel 7 | from einops import rearrange 8 | from xformers.ops import memory_efficient_attention 9 | 10 | from .motionmodule import get_motion_module 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | # SpatialAttentionModule is a spatial attention module between reference and input 15 | class SpatialAttentionModule(nn.Module): 16 | def __init__(self, num_inp_channels: int, embed_dim: int = 40, num_heads: int = 8) -> None: 17 | super(SpatialAttentionModule, self).__init__() 18 | 19 | self.num_inp_channels = num_inp_channels 20 | self.embed_dim = embed_dim 21 | self.num_heads = num_heads 22 | 23 | # create input projection layers 24 | self.norm_in = nn.GroupNorm(num_groups=32, num_channels=num_inp_channels, eps=1e-6, affine=True) 25 | self.proj_in = nn.Conv2d(num_inp_channels, num_inp_channels, kernel_size=1, stride=1, padding=0) 26 | 27 | # create multiheaded attention module 28 | self.to_q = nn.Linear(num_inp_channels, embed_dim) 29 | self.to_k = nn.Linear(num_inp_channels, embed_dim) 30 | self.to_v = nn.Linear(num_inp_channels, embed_dim) 31 | self.norm1 = nn.LayerNorm(embed_dim) 32 | self.ffn = nn.Linear(embed_dim, embed_dim) 33 | self.norm2 = nn.LayerNorm(embed_dim) 34 | 35 | # create output projection layer 36 | self.proj_out = nn.Conv2d(num_inp_channels, num_inp_channels, kernel_size=1, stride=1, padding=0) 37 | 38 | # forward passes the activation through a spatial attention module 39 | def forward(self, x, reference_tensor): 40 | # expand and concat x with reference embedding where x is [b*t,c,h,w] 41 | orig_w = x.shape[3] 42 | concat = torch.cat((x, reference_tensor), axis=3) 43 | h, w = concat.shape[2], concat.shape[3] 44 | 45 | # pass data through input projections 46 | proj_x = self.norm_in(concat) 47 | proj_x = self.proj_in(proj_x) 48 | 49 | # re-arrange data from (b*t,c,h,w) to correct groupings to [b*t,w*h,c] 50 | grouped_x = rearrange(proj_x, 'bt c h w -> bt (h w) c') 51 | reshaped_x = rearrange(x, 'bt c h w -> bt (h w) c') 52 | 53 | # compute self-attention on the concatenated data along w dimension 54 | q, k, v = self.to_q(reshaped_x), self.to_k(grouped_x), self.to_v(grouped_x) 55 | 56 | # split embeddings for multi-headed attention 57 | q = rearrange(q, 'bt (h w) (n d) -> bt (h w) n d', h=x.shape[2], w=x.shape[3], n=self.num_heads) 58 | k = rearrange(k, 'bt (h w) (n d) -> bt (h w) n d', h=h, w=w, n=self.num_heads) 59 | v = rearrange(v, 'bt (h w) (n d) -> bt (h w) n d', h=h, w=w, n=self.num_heads) 60 | 61 | # run attention calculation 62 | attn_out = memory_efficient_attention(q, k, v) 63 | # reshape from multihead 64 | attn_out = rearrange(attn_out, 'bt (h w) n d -> bt (h w) (n d)', h=x.shape[2], w=x.shape[3], n=self.num_heads) 65 | 66 | norm1_out = self.norm1(attn_out + reshaped_x) 67 | ffn_out = self.ffn(norm1_out) 68 | attn_out = self.norm2(norm1_out + ffn_out) 69 | 70 | # re-arrange data from (b*t,w*h,c) to (b*t,c,h,w) 71 | attn_out = rearrange(attn_out, 'bt (h w) c -> bt c h w', h=x.shape[2], w=x.shape[3]) 72 | 73 | # pass output through out projection 74 | out = self.proj_out(attn_out) 75 | 76 | # return sliced out with x as adding residual before reshape would be the same as adding x 77 | return out + x 78 | 79 | 80 | # 
TemporalAttentionModule is a temporal attention module
81 | class TemporalAttentionModule(nn.Module):
82 |     def __init__(self, num_inp_channels: int, num_frames: int, embed_dim: int = 40, num_heads: int = 8) -> None:
83 |         super(TemporalAttentionModule, self).__init__()
84 | 
85 |         self.num_inp_channels = num_inp_channels
86 |         self.num_frames = num_frames
87 |         self.embed_dim = embed_dim
88 | 
89 |         # create input projection layers
90 |         self.norm_in = nn.GroupNorm(num_groups=32, num_channels=num_inp_channels, eps=1e-6, affine=True)
91 |         self.proj_in = nn.Conv2d(num_inp_channels, num_inp_channels, kernel_size=1, stride=1, padding=0)
92 | 
93 |         # create multiheaded attention module
94 |         self.to_q = nn.Linear(num_inp_channels, embed_dim)
95 |         self.to_k = nn.Linear(num_inp_channels, embed_dim)
96 |         self.to_v = nn.Linear(num_inp_channels, embed_dim)
97 |         self.norm1 = nn.LayerNorm(embed_dim)
98 |         self.ffn = nn.Linear(embed_dim, embed_dim)
99 |         self.norm2 = nn.LayerNorm(embed_dim)
100 | 
101 |         # create output projection layer
102 |         self.proj_out = nn.Conv2d(num_inp_channels, num_inp_channels, kernel_size=1, stride=1, padding=0)
103 | 
104 |     # forward performs temporal attention on the input (b*t,c,h,w)
105 |     def forward(self, x):
106 |         h, w = x.shape[2], x.shape[3]
107 | 
108 |         # pass data through input projections
109 |         proj_x = self.norm_in(x)
110 |         proj_x = self.proj_in(proj_x)
111 | 
112 |         # re-arrange the projected data from (b*t,c,h,w) to temporal groupings of (b*h*w,t,c)
113 |         grouped_x = rearrange(proj_x, '(b t) c h w -> (b h w) t c', t=self.num_frames)
114 | 
115 |         # perform self-attention on the grouped_x
116 |         q, k, v = self.to_q(grouped_x), self.to_k(grouped_x), self.to_v(grouped_x)
117 |         attn_out = memory_efficient_attention(q, k, v)
118 |         norm1_out = self.norm1(attn_out + grouped_x)
119 |         ffn_out = self.ffn(norm1_out)
120 |         attn_out = self.norm2(norm1_out + ffn_out)
121 | 
122 |         # rearrange out to be back into the grouped batch and timestep format
123 |         attn_out = rearrange(attn_out, '(b h w) t c -> (b t) c h w', t=self.num_frames, h=h, w=w)
124 | 
125 |         # pass attention output through out projection
126 |         attn_out = self.proj_out(attn_out)
127 | 
128 |         return attn_out + x
129 | 
130 | 
131 | # ReferenceConditionedAttentionBlock is an attention block which performs spatial and temporal attention
132 | class ReferenceConditionedAttentionBlock(nn.Module):
133 |     def __init__(self, cross_attn: Transformer2DModel, num_frames: int, skip_temporal_attn: bool = False):
134 |         super(ReferenceConditionedAttentionBlock, self).__init__()
135 | 
136 |         # store configurations and submodules
137 |         self.skip_temporal_attn = skip_temporal_attn
138 |         self.num_frames = num_frames
139 |         self.cross_attn = cross_attn
140 | 
141 |         # extract channel and embedding dimensions from the provided cross_attn
142 |         num_channels = cross_attn.config.in_channels
143 |         embed_dim = cross_attn.config.in_channels
144 |         self.sam = SpatialAttentionModule(num_channels, embed_dim=embed_dim)
145 |         self.tam = get_motion_module(num_channels,
146 |                                      motion_module_type='Vanilla',
147 |                                      motion_module_kwargs={})
148 | 
149 |         # store the reference tensor used by this module (this must be updated before the forward pass)
150 |         self.reference_tensor = None
151 | 
152 |     # update_reference_tensor updates the reference tensor for the module
153 |     def update_reference_tensor(self, reference_tensor: torch.FloatTensor):
154 |         self.reference_tensor = reference_tensor
155 | 
156 |     # update_num_frames updates the number of frames the temporal attention module is configured for
157 |     def 
update_num_frames(self, num_frames: int): 158 | self.num_frames = num_frames 159 | 160 | # forward performs spatial attention, cross attention, and temporal attention 161 | def forward( 162 | self, 163 | hidden_states: torch.Tensor, 164 | encoder_hidden_states: Optional[torch.Tensor] = None, 165 | timestep: Optional[torch.LongTensor] = None, 166 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 167 | class_labels: Optional[torch.LongTensor] = None, 168 | cross_attention_kwargs: Dict[str, Any] = None, 169 | attention_mask: Optional[torch.Tensor] = None, 170 | encoder_attention_mask: Optional[torch.Tensor] = None, 171 | return_dict: bool = True, 172 | ): 173 | # begin spatial attention 174 | 175 | # pass concat tensor through spatial attention module along w axis [bt,c,h,w] 176 | out = self.sam(hidden_states, self.reference_tensor) 177 | 178 | # begin cross attention 179 | out = self.cross_attn(out, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, 180 | cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)[0] 181 | 182 | # begin temporal attention 183 | if self.skip_temporal_attn: 184 | return (out,) 185 | 186 | # reshape data from [bt c h w] to be [b c t h w] 187 | temporal_input = rearrange(out, '(b t) c h w -> b c t h w', t=self.num_frames) 188 | 189 | # pass the data through the temporal attention module 190 | temporal_output = self.tam(temporal_input, None, None) 191 | 192 | # reshape temporal output back from [b c t h w] to [bt c h w] 193 | temporal_output = rearrange(temporal_output, 'b c t h w -> (b t) c h w') 194 | 195 | return (temporal_output,) 196 | 197 | 198 | # VideoNet is a unet initialized from stable diffusion used to denoise video frames 199 | class VideoNet(nn.Module): 200 | def __init__(self, sd_unet: UNet2DConditionModel, num_frames: int = 24, batch_size: int = 2): 201 | super(VideoNet, self).__init__() 202 | self.batch_size = batch_size 203 | 204 | # create a deep copy of the sd_unet 205 | self.unet = copy.deepcopy(sd_unet) 206 | 207 | # maintain a list of all the new ReferenceConditionedResNets and TemporalAttentionBlocks 208 | self.ref_cond_attn_blocks: List[ReferenceConditionedAttentionBlock] = [] 209 | 210 | # replace attention blocks with ReferenceConditionedAttentionBlock 211 | down_blocks = self.unet.down_blocks 212 | mid_block = self.unet.mid_block 213 | up_blocks = self.unet.up_blocks 214 | 215 | for i in range(len(down_blocks)): 216 | if hasattr(down_blocks[i], "attentions"): 217 | attentions = down_blocks[i].attentions 218 | for j in range(len(attentions)): 219 | attentions[j] = ReferenceConditionedAttentionBlock(attentions[j], num_frames) 220 | self.ref_cond_attn_blocks.append(attentions[j]) 221 | 222 | for i in range(len(mid_block.attentions)): 223 | mid_block.attentions[i] = ReferenceConditionedAttentionBlock(mid_block.attentions[i], num_frames) 224 | self.ref_cond_attn_blocks.append(mid_block.attentions[i]) 225 | 226 | for i in range(len(up_blocks)): 227 | if hasattr(up_blocks[i], "attentions"): 228 | attentions = up_blocks[i].attentions 229 | for j in range(len(attentions)): 230 | attentions[j] = ReferenceConditionedAttentionBlock(attentions[j], num_frames) 231 | self.ref_cond_attn_blocks.append(attentions[j]) 232 | 233 | # update_reference_embeddings updates all the reference embeddings in the unet 234 | def update_reference_embeddings(self, reference_embeddings): 235 | if len(reference_embeddings) != len(self.ref_cond_attn_blocks): 236 | print("[!] 
WARNING - number of input reference embeddings does not match the number of modules in VideoNet")
237 | 
238 |         for i in range(len(self.ref_cond_attn_blocks)):
239 |             # update the reference conditioned blocks embedding
240 |             self.ref_cond_attn_blocks[i].update_reference_tensor(reference_embeddings[i])
241 | 
242 |     # update_num_frames updates the frame count on all temporal attention blocks
243 |     def update_num_frames(self, num_frames):
244 |         for i in range(len(self.ref_cond_attn_blocks)):
245 |             # update the number of frames
246 |             self.ref_cond_attn_blocks[i].update_num_frames(num_frames)
247 | 
248 |     # update_skip_temporal_attn updates all the skip temporal attention attributes
249 |     def update_skip_temporal_attn(self, skip_temporal_attn):
250 |         for i in range(len(self.ref_cond_attn_blocks)):
251 |             # update the skip_temporal_attn attribute
252 |             self.ref_cond_attn_blocks[i].skip_temporal_attn = skip_temporal_attn
253 | 
254 |     # forward passes the noisy latents + conditioning embeddings to the unet and returns activations
255 |     def forward(self, initial_noise, timesteps, reference_embeddings, clip_condition_embeddings, skip_temporal_attn=False):
256 |         # update the reference tensors for the ReferenceConditionedAttentionBlock modules
257 |         self.update_reference_embeddings(reference_embeddings)
258 | 
259 |         # update the skip temporal attention attribute
260 |         self.update_skip_temporal_attn(skip_temporal_attn)
261 | 
262 |         # forward pass the noisy latents + conditioning embeddings through the unet
263 |         return self.unet(
264 |             initial_noise,
265 |             timesteps,
266 |             encoder_hidden_states=clip_condition_embeddings,
267 |         )[0]
268 | 
269 | 
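A minimal, self-contained sketch of the two token regroupings used by the attention modules in this file (the tensor sizes below are hypothetical and chosen only for illustration; xformers is not needed to run it). SpatialAttentionModule takes its queries from the denoising features but builds keys/values from those features concatenated with the reference features along the width axis, while the temporal path folds the spatial positions into the batch so attention runs across frames:

import torch
from einops import rearrange

bt, t, c, h, w = 8, 4, 320, 32, 32          # hypothetical: batch*frames, frames, channels, height, width
x = torch.randn(bt, c, h, w)                # denoising features for every frame
ref = torch.randn(bt, c, h, w)              # reference features, already repeated per frame

# spatial attention: queries from x, keys/values from x concatenated with the reference along width
concat = torch.cat((x, ref), dim=3)                        # (bt, c, h, 2w)
q_tokens = rearrange(x, 'bt c h w -> bt (h w) c')          # (8, 1024, 320)
kv_tokens = rearrange(concat, 'bt c h w -> bt (h w) c')    # (8, 2048, 320)

# temporal attention: fold the spatial grid into the batch so attention runs across the t frames
tokens = rearrange(x, '(b t) c h w -> (b h w) t c', t=t)   # (2*32*32, 4, 320)
assert torch.equal(rearrange(tokens, '(b h w) t c -> (b t) c h w', h=h, w=w), x)  # round trip is lossless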
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | moviepy
3 | torch
4 | torchvision
5 | transformers
6 | accelerate
7 | einops
8 | xformers
9 | omegaconf
10 | opencv-python
11 | Pillow
12 | numpy
13 | av
14 | tqdm
15 | tensorboard
16 | 
--------------------------------------------------------------------------------
/train_stage_1_referencenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import Dataset, DataLoader
5 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | from pathlib import Path
9 | from typing import Dict, List, Optional, Tuple
10 | from omegaconf import OmegaConf
11 | 
12 | class EMODatasetStage1(Dataset):
13 |     """
14 |     Stage 1 dataset focused purely on frame encoding.
15 |     Only provides single frames for training the ReferenceNet and VAE.
16 |     """
17 |     def __init__(
18 |         self,
19 |         data_dir: str,
20 |         video_dir: str,
21 |         json_file: str,
22 |         width: int = 512,
23 |         height: int = 512,
24 |         transform = None
25 |     ):
26 |         self.data_dir = Path(data_dir)
27 |         self.video_dir = Path(video_dir)
28 |         self.width = width
29 |         self.height = height
30 | 
31 |         # Default transform if none provided
32 |         self.transform = transform or transforms.Compose([
33 |             transforms.Resize((height, width)),
34 |             transforms.ToTensor(),
35 |             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
36 |         ])
37 | 
38 |         # Load video metadata
39 |         import json
40 |         with open(json_file, 'r') as f:
41 |             self.data = json.load(f)
42 |         self.video_ids = list(self.data['clips'].keys())
43 | 
44 |     def __len__(self) -> int:
45 |         return len(self.video_ids)
46 | 
47 |     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
48 |         """
49 |         Returns a single frame for training.
50 |         For Stage 1, we only need individual frames.
51 |         """
52 |         video_id = self.video_ids[idx]
53 |         video_path = self.video_dir / f"{video_id}.mp4"
54 | 
55 |         # Read a random frame from the video
56 |         import cv2
57 |         cap = cv2.VideoCapture(str(video_path))
58 | 
59 |         # Get random frame
60 |         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
61 |         target_frame = torch.randint(0, total_frames, (1,)).item()
62 |         cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
63 | 
64 |         ret, frame = cap.read()
65 |         cap.release()
66 | 
67 |         if not ret:
68 |             raise ValueError(f"Could not read frame from video: {video_path}")
69 | 
70 |         # Convert BGR to RGB
71 |         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72 |         frame = Image.fromarray(frame)
73 | 
74 |         # Apply transforms
75 |         frame_tensor = self.transform(frame)
76 | 
77 |         return {
78 |             'pixel_values': frame_tensor,
79 |             'video_id': video_id
80 |         }
81 | 
82 | class ReferenceNet(nn.Module):
83 |     """
84 |     ReferenceNet: Extracts reference features from input frames.
85 |     Based on the SD UNet architecture but modified for reference feature extraction.
86 |     """
87 |     def __init__(self, unet: UNet2DConditionModel):
88 |         super().__init__()
89 |         self.unet = unet
90 | 
91 |         # Freeze most UNet parameters except final blocks
92 |         for name, param in self.unet.named_parameters():
93 |             if 'up_blocks.3' not in name:  # Only fine-tune the final up block
94 |                 param.requires_grad = False
95 | 
96 |     def forward(self, latents: torch.Tensor, timesteps: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None) -> torch.Tensor:
97 |         """Extract reference features through the modified SD UNet (with None, its cross-attention layers fall back to self-attention)."""
98 |         return self.unet(latents, timesteps, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]
99 | 
100 | def train_stage1(config: OmegaConf) -> None:
101 |     """
102 |     Stage 1 training focusing on frame encoding with ReferenceNet and VAE.
103 |     """
104 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105 | 
106 |     # Initialize dataset
107 |     dataset = EMODatasetStage1(
108 |         data_dir=config.data.data_dir,
109 |         video_dir=config.data.video_dir,
110 |         json_file=config.data.json_file,
111 |         width=config.data.train_width,
112 |         height=config.data.train_height
113 |     )
114 | 
115 |     dataloader = DataLoader(
116 |         dataset,
117 |         batch_size=config.training.batch_size,
118 |         shuffle=True,
119 |         num_workers=config.training.num_workers
120 |     )
121 | 
122 |     # Initialize models
123 |     # 1. VAE from Stable Diffusion
124 |     vae = AutoencoderKL.from_pretrained(
125 |         "stabilityai/sd-vae-ft-mse"
126 |     ).to(device)
127 |     vae.eval()  # VAE stays frozen; it is only used under torch.no_grad() below
128 | 
129 |     # 2. UNet from Stable Diffusion for ReferenceNet
130 |     reference_unet = UNet2DConditionModel.from_pretrained(
131 |         "runwayml/stable-diffusion-v1-5",
132 |         subfolder="unet"
133 |     ).to(device)
134 | 
135 |     # 3. 
Initialize ReferenceNet 136 | reference_net = ReferenceNet(reference_unet).to(device) 137 | 138 | # Initialize optimizer (only for ReferenceNet) 139 | optimizer = torch.optim.AdamW( 140 | filter(lambda p: p.requires_grad, reference_net.parameters()), 141 | lr=config.training.learning_rate 142 | ) 143 | 144 | # Initialize noise scheduler 145 | noise_scheduler = DDPMScheduler( 146 | num_train_timesteps=1000, 147 | beta_start=0.00085, 148 | beta_end=0.012, 149 | beta_schedule="scaled_linear" 150 | ) 151 | 152 | # Training loop 153 | for epoch in range(config.training.num_epochs): 154 | total_loss = 0 155 | reference_net.train() 156 | 157 | for step, batch in enumerate(dataloader): 158 | # Get input images 159 | images = batch['pixel_values'].to(device) 160 | 161 | # Encode images to latent space using frozen VAE 162 | with torch.no_grad(): 163 | latents = vae.encode(images).latent_dist.sample() 164 | latents = latents * 0.18215 165 | 166 | # Add noise to latents 167 | noise = torch.randn_like(latents) 168 | timesteps = torch.randint( 169 | 0, noise_scheduler.config.num_train_timesteps, 170 | (images.shape[0],), device=device 171 | ).long() 172 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 173 | 174 | # Get model prediction 175 | noise_pred = reference_net(noisy_latents, timesteps) 176 | 177 | # Calculate loss 178 | loss = F.mse_loss(noise_pred, noise) 179 | 180 | # Backpropagation 181 | optimizer.zero_grad() 182 | loss.backward() 183 | optimizer.step() 184 | 185 | total_loss += loss.item() 186 | 187 | # Log progress 188 | if step % config.training.log_every == 0: 189 | print(f"Epoch {epoch+1}/{config.training.num_epochs}, " 190 | f"Step {step}/{len(dataloader)}, " 191 | f"Loss: {loss.item():.4f}") 192 | 193 | # Save checkpoint 194 | if (epoch + 1) % config.training.save_every == 0: 195 | checkpoint = { 196 | 'epoch': epoch, 197 | 'reference_net_state_dict': reference_net.state_dict(), 198 | 'optimizer_state_dict': optimizer.state_dict(), 199 | 'loss': total_loss / len(dataloader), 200 | } 201 | torch.save( 202 | checkpoint, 203 | f"{config.training.checkpoint_dir}/stage1_epoch_{epoch+1}.pt" 204 | ) 205 | 206 | if __name__ == "__main__": 207 | # Load config 208 | config = OmegaConf.load("configs/stage1.yaml") 209 | train_stage1(config) -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | # **************************************************************************** # 2 | # # 3 | # ::: :::::::: # 4 | # video.py :+: :+: :+: # 5 | # +:+ +:+ +:+ # 6 | # By: taston +#+ +:+ +#+ # 7 | # +#+#+#+#+#+ +#+ # 8 | # Created: 2023/04/25 10:00:46 by taston #+# #+# # 9 | # Updated: 2023/05/30 10:09:03 by taston ### ########.fr # 10 | # # 11 | # **************************************************************************** # 12 | 13 | import cv2 14 | 15 | class Video: 16 | """ 17 | A class representing a Video 18 | 19 | ... 
20 | 21 | Attributes 22 | ---------- 23 | cap : 24 | 25 | filename : str 26 | path to video file 27 | fps : int 28 | framerate in frames per second 29 | height : int 30 | pixel height of video 31 | total_frames : int 32 | length of video in frames 33 | width : int 34 | pixel width of video 35 | writer : 36 | 37 | 38 | Methods 39 | ------- 40 | create_writer() 41 | creates video writer object 42 | get_dim() 43 | gets video dimensions 44 | get_fps() 45 | gets video fps 46 | get_length() 47 | gets length of video in frames 48 | """ 49 | def __init__(self, filename=None): 50 | if filename: 51 | self.filename = filename 52 | self.cap = self._open_vid() 53 | self.width, self.height = self.get_dim() 54 | self.total_frames = self.get_length() 55 | self.fps = self.get_fps() 56 | 57 | def _open_vid(self): 58 | ''' 59 | Open specified video file and create capture object 60 | ''' 61 | cap = cv2.VideoCapture(self.filename) 62 | return cap 63 | 64 | def create_writer(self): 65 | ''' 66 | Create writer object for recording videos based on chosen video. 67 | ''' 68 | self.writer = cv2.VideoWriter('calibration.mp4', 69 | cv2.VideoWriter_fourcc(*'mp4v'), 70 | self.fps, 71 | (self.width, self.height)) 72 | 73 | def get_dim(self): 74 | ''' 75 | Get video resolution 76 | ''' 77 | width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 78 | height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 79 | return width, height 80 | 81 | def get_length(self): 82 | ''' 83 | Get length of video in frames 84 | ''' 85 | total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 86 | return total_frames 87 | 88 | def get_fps(self): 89 | ''' 90 | Get framerate of video in fps 91 | ''' 92 | fps = round(self.cap.get(cv2.CAP_PROP_FPS)) 93 | return fps 94 | 95 | def __str__(self): 96 | return ('-'*60 + '\n' + 97 | 'Video data:' + '\n' + 98 | '-'*60 + '\n' + 99 | f'Video resolution: {self.width} x {self.height} pixels' + '\n' + 100 | f'Length of video: {self.total_frames} frames' + '\n' + 101 | f'Framerate: {self.fps} fps' + '\n' + 102 | '-'*60) 103 | # return 'Video' 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /videonet_animatediff.py: -------------------------------------------------------------------------------- 1 | import time 2 | from os.path import join 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.models import AutoencoderKL 7 | from diffusers.schedulers import PNDMScheduler 8 | from diffusers.image_processor import VaeImageProcessor 9 | from transformers import CLIPVisionModel, CLIPImageProcessor 10 | from torch.utils.data import DataLoader 11 | from einops import rearrange, repeat 12 | from accelerate import Accelerator, DistributedDataParallelKwargs 13 | import torch.nn.functional as F 14 | from tqdm import tqdm 15 | from diffusers import StableDiffusionImageVariationPipeline, StableDiffusionPipeline 16 | import copy 17 | from typing import List, Optional, Dict, Any 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from diffusers.models import UNet2DConditionModel, Transformer2DModel 22 | from einops import rearrange 23 | from xformers.ops import memory_efficient_attention 24 | from models.motionmodule import get_motion_module 25 | import tensorboard 26 | from torch.utils.tensorboard import SummaryWriter 27 | from models.videonet import VideoNet 28 | from magicanimate.models.unet_controlnet import UNet3DConditionModel 29 | torch.manual_seed(17) 30 | 31 | import pkg_resources 32 | 
from omegaconf import OmegaConf 33 | for entry_point in pkg_resources.iter_entry_points('tensorboard_plugins'): 34 | print("tensorboard_plugins:",entry_point.dist) 35 | 36 | 37 | 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | # load_mm loads a motion module into video net 41 | def load_mm(video_net: VideoNet, mm_state_dict): 42 | refactored_mm_state_dict = {} 43 | for key in mm_state_dict: 44 | key_split = key.split('.') 45 | 46 | # modify the key split to have the correct arguments (except first unet) 47 | key_split[2] = 'attentions' 48 | key_split.insert(4, 'tam') 49 | new_key = '.'.join(key_split) 50 | refactored_mm_state_dict[new_key] = mm_state_dict[key] 51 | 52 | # load the modified weights into video_net 53 | _, unexpected = video_net.unet.load_state_dict(refactored_mm_state_dict, strict=False) 54 | 55 | return 56 | 57 | 58 | """ Why sd-image-variations-diffusers? 59 | The concept of sd-image-variations-diffusers appears to differ from normal Stable Diffusion (SD) in the focus on generating variations of an existing image or theme. Here’s how it stands out: 60 | 61 | Purpose of Variations: While normal SD primarily generates images from textual descriptions starting from scratch, sd-image-variations-diffusers seems to specialize in creating different versions or slight modifications of an existing image. This can be particularly useful for exploring alternative possibilities, fine-tuning details, or generating multiple iterations of a concept. 62 | 63 | Control and Consistency: Generating variations likely involves maintaining certain aspects of the original image constant, such as the overall theme, composition, or key elements, while altering others. This differs from the usual SD process, where each new generation can result in widely different images even with similar text prompts. 64 | 65 | Technique and Process: The use of the term “diffusers” suggests a specific approach or technique within the diffusion model framework, perhaps focusing on controlled manipulation of the image generation process. This could involve sophisticated methods to ensure that the variations are coherent and aligned with the original image’s characteristics. 66 | 67 | Targeted Creativity: sd-image-variations-diffusers may provide tools for more targeted creativity, allowing artists and users to iterate on a concept or visual idea more precisely. This could be useful in scenarios where the initial concept is clear, but the execution requires experimentation with variations to find the ideal manifestation. 68 | 69 | In summary, the difference lies in the specific application and functionality of generating nuanced variations of an image, as opposed to generating entirely new images from text descriptions. 
70 | """ 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | if __name__ == '__main__': 79 | num_frames = 24 80 | 81 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema") 82 | # construct pipe from imag evariation diffuser 83 | pipe = StableDiffusionImageVariationPipeline.from_pretrained('/media/2TB/ani/animate-anyone/pretrained_models/sd-image-variations-diffusers', revision="v2.0", vae=vae).to(device) 84 | 85 | video_net = VideoNet(pipe.unet, num_frames=num_frames).to("cuda") 86 | 87 | 88 | # load mm pretrained weights from animatediff 89 | load_mm(video_net, torch.load('/media/2TB/stable-diffusion-webui/extensions/sd-webui-animatediff/model/v3_sd15_mm.ckpt')) 90 | 91 | for name, module in video_net.named_modules(): 92 | print(f" name:{name} layer:{module.__class__.__name__}") 93 | 94 | 95 | # Step 2: Initialize the TensorBoard SummaryWriter 96 | # writer = SummaryWriter('runs/videonet_experiment') 97 | 98 | 99 | # inference_config = OmegaConf.load("configs/inference.yaml") 100 | 101 | # unet = UNet3DConditionModel.from_pretrained_2d('/media/2TB/ani/animate-anyone/pretrained_models/sd-image-variations-diffusers', subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda() 102 | 103 | # # Get the correct number of latent dimensions from the model's configuration 104 | # num_channels_latent = 4 # This should be verified from the model's configuration 105 | 106 | # # Create dummy data 107 | # batch_size = 1 108 | # frames = 16 109 | # height = 512 110 | # width = 512 111 | 112 | # # Create the sample tensor (noisy latent image) 113 | # sample = torch.randn(batch_size, num_channels_latent, frames, height, width).to(device) 114 | 115 | # # Create the timestep tensor 116 | # timestep = torch.tensor([1]).to(device) # Replace 50 with the desired timestep value 117 | 118 | # # Create the encoder hidden states tensor 119 | # encoder_hidden_states = torch.randn(batch_size, frames, 1).to(device) # Assuming 768 is the hidden size 120 | 121 | # # Create the class labels tensor (optional) 122 | # class_labels = None # Set to None if not using class conditioning 123 | 124 | # # Perform the forward pass 125 | # with torch.no_grad(): 126 | # output = unet(sample, timestep, None, class_labels) --------------------------------------------------------------------------------