├── README.md ├── assets ├── boy.jpeg ├── boy_cropped.jpg ├── combined_with_transitions_24.mp4 └── driving.mp4 ├── configs └── inference │ └── inference.yaml ├── crop_process.py ├── data_utils ├── __pycache__ │ ├── datasets_faceswap.cpython-310.pyc │ └── transfer_utils.cpython-310.pyc ├── datasets_faceswap.py └── transfer_utils.py ├── decalib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── deca_with_smirk.cpython-310.pyc ├── datasets │ ├── __pycache__ │ │ ├── datasets.cpython-310.pyc │ │ └── detectors.cpython-310.pyc │ ├── aflw2000.py │ ├── build_datasets.py │ ├── datasets.py │ ├── detectors.py │ ├── ethnicity.py │ ├── now.py │ ├── train_datasets.py │ ├── vggface.py │ └── vox.py ├── deca_with_smirk.py ├── models │ ├── FLAME.py │ ├── __pycache__ │ │ ├── FLAME.cpython-310.pyc │ │ ├── decoders.cpython-310.pyc │ │ ├── encoders.cpython-310.pyc │ │ ├── lbs.cpython-310.pyc │ │ └── resnet.cpython-310.pyc │ ├── decoders.py │ ├── encoders.py │ ├── frnet.py │ ├── lbs.py │ └── resnet.py ├── smirk │ ├── __pycache__ │ │ ├── mediapipe_utils.cpython-310.pyc │ │ └── smirk_encoder.cpython-310.pyc │ ├── face_landmarker.task │ ├── mediapipe_utils.py │ ├── smirk_encoder.py │ └── utils │ │ ├── masking.py │ │ └── utils.py ├── trainer.py └── utils │ ├── __pycache__ │ ├── config.cpython-310.pyc │ ├── renderer.cpython-310.pyc │ ├── rotation_converter.cpython-310.pyc │ ├── tensor_cropper.cpython-310.pyc │ └── util.cpython-310.pyc │ ├── config.py │ ├── lossfunc.py │ ├── rasterizer │ ├── INSTALL.md │ ├── __init__.py │ ├── setup.py │ ├── standard_rasterize_cuda.cpp │ └── standard_rasterize_cuda_kernel.cu │ ├── renderer.py │ ├── rotation_converter.py │ ├── tensor_cropper.py │ ├── trainer.py │ └── util.py ├── inference.py ├── models ├── __pycache__ │ ├── attention.cpython-310.pyc │ ├── exp_encoder.cpython-310.pyc │ ├── guidance_encoder.cpython-310.pyc │ ├── mgportrait_model.cpython-310.pyc │ ├── motion_module.cpython-310.pyc │ ├── mutual_self_attention.cpython-310.pyc │ ├── resnet.cpython-310.pyc │ ├── transformer_2d.cpython-310.pyc │ ├── transformer_3d.cpython-310.pyc │ ├── unet_2d_blocks.cpython-310.pyc │ ├── unet_2d_condition.cpython-310.pyc │ ├── unet_3d.cpython-310.pyc │ └── unet_3d_blocks.cpython-310.pyc ├── attention.py ├── exp_encoder.py ├── guidance_encoder.py ├── mgportrait_model.py ├── motion_module.py ├── mutual_self_attention.py ├── resnet.py ├── transformer_2d.py ├── transformer_3d.py ├── unet_2d_blocks.py ├── unet_2d_condition.py ├── unet_3d.py └── unet_3d_blocks.py ├── pipelines ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── context.cpython-310.pyc │ ├── pipe_utils.cpython-310.pyc │ └── pipeline_aggregation.cpython-310.pyc ├── context.py ├── pipe_utils.py └── pipeline_aggregation.py ├── render_and_transfer.py ├── requirements.txt └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc └── video_utils.cpython-310.pyc ├── download.py ├── fs.py ├── postprocess.py ├── tb_tracker.py ├── util.py └── video_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # MagicPortrait 2 | 3 | **MagicPortrait: Temporally Consistent Face Reenactment with 3D Geometric Guidance** 4 |
5 | [Mengting Wei](), 6 | [Yante Li](), 7 | [Tuomas Varanka](), 8 | [Yan Jiang](), 9 | [Guoying Zhao]() 10 |
11 | 
12 | _[arXiv](https://arxiv.org/abs/2504.21497) | [Model](https://huggingface.co/mengtingwei/MagicPortrait)_
13 | 
14 | This repository contains the example inference script for the MagicPortrait-preview model.
15 | 
16 | https://github.com/user-attachments/assets/9471fdd9-948a-47dd-a632-2adfc631be50
17 | 
18 | ## Installation
19 | 
20 | ```bash
21 | conda create -n mgp python=3.10 -y
22 | conda activate mgp
23 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
24 | pip install -r requirements.txt
25 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
26 | pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/download.html
27 | ```
28 | 
29 | ## Inference with the Model
30 | 
31 | 
32 | ### Step 1: Download pre-trained models
33 | 
34 | Download our models from Hugging Face:
35 | 
36 | ```bash
37 | huggingface-cli download --resume-download mengtingwei/MagicPortrait --local-dir ./pre_trained
38 | ```
39 | ### Step 2: Set up the necessary libraries for face motion transfer
40 | 0. Put the `third_party_files` folder downloaded in the previous step under the project directory `./`.
41 | 1. Visit the [DECA GitHub repository](https://github.com/yfeng95/DECA?tab=readme-ov-file) to download the pretrained `deca_model.tar`.
42 | 2. Visit the [FLAME website](https://flame.is.tue.mpg.de/download.php) to download `FLAME 2020` and extract `generic_model.pkl`.
43 | 3. Visit the [FLAME website](https://flame.is.tue.mpg.de/download.php) to download `FLAME texture space` and extract `FLAME_texture.npz`.
44 | 4. Visit [DECA's data page](https://github.com/yfeng95/DECA/tree/master/data) and download all files.
45 | 5. Visit the [SMIRK repository](https://github.com/georgeretsi/smirk) to download `SMIRK_em1.pt`.
46 | 6. Place the files in their corresponding locations as specified below (a quick check script is sketched after the tree).
47 | 
48 | ```plaintext
49 | decalib
50 |     data/
51 |         deca_model.tar
52 |         generic_model.pkl
53 |         FLAME_texture.npz
54 |         fixed_displacement_256.npy
55 |         head_template.obj
56 |         landmark_embedding.npy
57 |         mean_texture.jpg
58 |         texture_data_256.npy
59 |         uv_face_eye_mask.png
60 |         uv_face_mask.png
61 |         ...
62 |     smirk/
63 |         pretrained_models/
64 |             SMIRK_em1.pt
65 |         ...
66 |     ...
67 | ```
68 | 
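Before moving on, it can help to confirm that the manually downloaded files ended up where the code expects them. The following is a minimal sketch (not part of the repository); extend the list with the files from DECA's data page for a stricter check:

```python
# Minimal sanity check (not part of the repo): verify the Step 2 files are in place.
import os

required = [
    'decalib/data/deca_model.tar',                    # from the DECA repository
    'decalib/data/generic_model.pkl',                 # from FLAME 2020
    'decalib/data/FLAME_texture.npz',                 # from the FLAME texture space
    'decalib/smirk/pretrained_models/SMIRK_em1.pt',   # from the SMIRK repository
]

missing = [path for path in required if not os.path.exists(path)]
print('All required files found.' if not missing else f'Missing files: {missing}')
```
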
69 | ### Step 3: Process the identity image and driving video
70 | 
71 | > As our model is designed to focus only on the face,
72 | > you should crop the face from your images or videos if they are full-body shots.
73 | > However, **if your images or videos already contain only the face and the aspect ratio is approximately 1:1,
74 | > you can simply resize them to 512 $\times$ 512 and skip the following cropping steps (1 and 2).**
75 | 
76 | 1. Crop the face from an image:
77 | 
78 | ```bash
79 | python crop_process.py --sign image --img_path './assets/boy.jpeg' --save_path './assets/boy_cropped.jpg'
80 | ```
81 | 
82 | 2. Crop the face sequence from the driving video.
83 | 
84 | * If you start from a video file, first extract its frames:
85 | ```bash
86 | mkdir ./assets/driving_images
87 | ffmpeg -i ./assets/driving.mp4 ./assets/driving_images/frame_%04d.jpg
88 | ```
89 | Then crop the face from the extracted driving frames:
90 | ```bash
91 | python crop_process.py --sign video --video_path './assets/driving_images' --video_imgs_dir './assets/driving_images_cropped'
92 | ```
93 | 
94 | 3. Retrieve the guidance images using the DECA and SMIRK models:
95 | 
96 | ```bash
97 | python render_and_transfer.py --sor_img './assets/boy_cropped.jpg' --driving_path './assets/driving_images_cropped' --save_name example1
98 | ```
99 | The guidance will be saved in the `./transfers` directory.
100 | 
101 | ### Step 4: Inference
102 | 
103 | Update the model and image paths in `./configs/inference/inference.yaml` to match your own file locations.
104 | 
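Before launching, it can be worth verifying that every path referenced in the config actually exists, so a long run does not fail on a missing checkpoint. The snippet below is a minimal sketch (not part of the repository), assuming PyYAML is installed and the key layout of the shipped `inference.yaml`:

```python
# Minimal pre-flight check (not part of the repo): verify the paths in the
# inference config exist before starting inference.
import os
import yaml

with open('configs/inference/inference.yaml') as f:
    cfg = yaml.safe_load(f)

paths = {
    'ref_image_path': cfg['data']['ref_image_path'],
    'guidance_data_folder': cfg['data']['guidance_data_folder'],
    'base_model_path': cfg['base_model_path'],
    'vae_model_path': cfg['vae_model_path'],
    'ckpt_dir': cfg['ckpt_dir'],
    'motion_module_path': cfg['motion_module_path'],
}
for name, path in paths.items():
    print(f"{name}: {path} -> {'OK' if os.path.exists(path) else 'MISSING'}")
```
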
105 | Then run:
106 | ```bash
107 | python inference.py
108 | ```
109 | 
110 | ## Acknowledgement
111 | 
112 | Our work is made possible by pioneering open-source
113 | 3D face reconstruction works (including [DECA](https://github.com/yfeng95/DECA?tab=readme-ov-file)
114 | and [SMIRK](https://github.com/georgeretsi/smirk)) and by
115 | the high-quality talking-video dataset [CelebV-HQ](https://celebv-hq.github.io).
116 | 
117 | ## Contact
118 | Open an issue here or email [mengting.wei@oulu.fi](mailto:mengting.wei@oulu.fi).
--------------------------------------------------------------------------------
/assets/boy.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/boy.jpeg
--------------------------------------------------------------------------------
/assets/boy_cropped.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/boy_cropped.jpg
--------------------------------------------------------------------------------
/assets/combined_with_transitions_24.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/combined_with_transitions_24.mp4
--------------------------------------------------------------------------------
/assets/driving.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/driving.mp4
--------------------------------------------------------------------------------
/configs/inference/inference.yaml:
--------------------------------------------------------------------------------
1 | exp_name: Animation
2 | width: 512
3 | height: 512
4 | data:
5 |   ref_image_path: '.../assets/girl1_cropped.jpg' # reference image path
6 |   guidance_data_folder: '.../transfers/example1' # corresponding motion sequence folder
7 |   frame_range: [0, 100] # [Optional] specify a frame range: [min_frame_idx, max_frame_idx] to select a clip from a motion sequence
8 | seed: 42
9 | 
10 | base_model_path: '.../stable-diffusion-v1-5'
11 | vae_model_path: '.../sd-vae-ft-mse'
12 | 
13 | 
14 | ckpt_dir: '.../ckpts/'
15 | motion_module_path: '.../ckpts/motion_module-47360.pth'
16 | 
17 | num_inference_steps: 20
18 | guidance_scale: 3.5
19 | enable_zero_snr: true
20 | weight_dtype: "fp16"
21 | 
22 | guidance_types:
23 |   - 'depth'
24 |   - 'normal'
25 |   - 'render'
26 | 
27 | noise_scheduler_kwargs:
28 |   num_train_timesteps: 1000
29 |   beta_start: 0.00085
30 |   beta_end: 0.012
31 |   beta_schedule: "linear"
32 |   steps_offset: 1
33 |   clip_sample: false
34 | 
35 | unet_additional_kwargs:
36 |   use_inflated_groupnorm: true
37 |   unet_use_cross_frame_attention: false
38 |   unet_use_temporal_attention: false
39 |   use_motion_module: true
40 |   motion_module_resolutions:
41 |     - 1
42 |     - 2
43 |     - 4
44 |     - 8
45 |   motion_module_mid_block: true
46 |   motion_module_decoder_only: false
47 |   motion_module_type: Vanilla
48 |   motion_module_kwargs:
49 |     num_attention_heads: 8
50 |     num_transformer_block: 1
51 |     attention_block_types:
52 |       - Temporal_Self
53 |       - Temporal_Self
54 |     temporal_position_encoding: true
55 |     temporal_position_encoding_max_len: 32
56 |     temporal_attention_dim_div: 1
57 | 
58 | guidance_encoder_kwargs:
59 |   guidance_embedding_channels: 320
60 |   guidance_input_channels: 3
61 |   block_out_channels: [16, 32, 96, 256]
62 | 
63 | enable_xformers_memory_efficient_attention: false
--------------------------------------------------------------------------------
/crop_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | from insightface.app import FaceAnalysis
8 | from torchvision.utils import save_image
9 | 
10 | import data_utils.datasets_faceswap as datasets_faceswap
11 | 
12 | 
13 | 
14 | pil2tensor = transforms.ToTensor()
15 | 
16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'),
17 |                    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
18 | app.prepare(ctx_id=0, det_size=(640, 640))
19 | 
20 | 
21 | def get_bbox(dets, crop_ratio):
22 |     if crop_ratio > 0:
23 |         bbox = dets[0:4]
24 |         bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
25 |         bbox_x = 0.5 * (bbox[2] + bbox[0])
26 |         bbox_y = 0.5 * (bbox[3] + bbox[1])
27 |         x1 = bbox_x - bbox_size * crop_ratio
28 |         x2 = bbox_x + bbox_size * crop_ratio
29 |         y1 = bbox_y - bbox_size * crop_ratio
30 |         y2 = bbox_y + bbox_size * crop_ratio
31 |         bbox_pts4 = np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)
32 |     else:
33 |         # original box
34 |         bbox = dets[0:4].reshape((2, 2))
35 |         bbox_pts4 = datasets_faceswap.get_box_lm4p(bbox)
36 |     return bbox_pts4
37 | 
38 | 
39 | 
40 | def crop_one_image(args):
41 |     cur_img_sor_path = args.img_path
42 |     im_pil_sor = Image.open(cur_img_sor_path).convert("RGB")
43 |     face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR))
44 |     assert len(face_info_sor) >= 1, 'The input image must contain a face!'
45 |     if len(face_info_sor) > 1:
46 |         print('The input image contains more than one face; only the largest face will be used.')
47 |     face_info_sor = \
48 |         sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
49 |     dets_sor = face_info_sor['bbox']
50 | 
51 |     bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75)
52 | 
53 |     warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor,
54 |                                                                      datasets_faceswap.mean_box_lm4p_512)
55 |     im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR)
56 | 
57 |     im_pil_sor = Image.fromarray(im_crop512_sor)
58 |     im_pil_sor = pil2tensor(im_pil_sor)
59 |     save_image(im_pil_sor, args.save_path)
60 | 
61 | 
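# The directory variant below detects the face box only once, on the first frame of
# the sequence, and applies that same square crop to every frame, so the crop window
# stays fixed across the whole driving video.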
62 | def crop_a_directory(args):
63 |     os.makedirs(args.video_imgs_dir, exist_ok=True)
64 |     imgs = sorted(os.listdir(args.video_path))
65 |     first_img_path = os.path.join(args.video_path, imgs[0])
66 |     im_pil_sor = Image.open(first_img_path).convert("RGB")
67 |     face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR))
68 |     assert len(face_info_sor) >= 1, 'The input image must contain a face!'
69 |     if len(face_info_sor) > 1:
70 |         print('The input image contains more than one face; only the largest face will be used.')
71 |     face_info_sor = \
72 |         sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
73 |     dets_sor = face_info_sor['bbox']
74 |     for img in imgs:
75 |         cur_img_path = os.path.join(args.video_path, img)
76 |         im_pil_sor = Image.open(cur_img_path).convert("RGB")
77 |         bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75)
78 | 
79 |         warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor,
80 |                                                                          datasets_faceswap.mean_box_lm4p_512)
81 |         im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR)
82 | 
83 |         im_pil_sor = Image.fromarray(im_crop512_sor)
84 |         im_pil_sor = pil2tensor(im_pil_sor)
85 |         save_image(im_pil_sor, os.path.join(args.video_imgs_dir, img))
86 | 
87 | 
88 | if __name__ == '__main__':
89 |     parser = argparse.ArgumentParser()
90 |     parser.add_argument(
91 |         "--sign",
92 |         type=str,
93 |         default='image',
94 |         required=False
95 |     )
96 |     # ********************************* identity image ***************************************
97 |     parser.add_argument(
98 |         "--img_path",
99 |         type=str,
100 |         default='/home/mengting/projects/process_scripts/test_images/boy2.jpeg',
101 |         required=False
102 |     )
103 |     parser.add_argument(
104 |         "--save_path",
105 |         type=str,
106 |         default='/home/mengting/projects/process_scripts/test_images/boy2_cropped.jpg',
107 |         required=False
108 |     )
109 |     # ********************************** driving video **************************************
110 |     parser.add_argument(
111 |         "--video_path",
112 |         type=str,
113 |         default='/home/mengting/projects/process_scripts/test_images/target_images1',
114 |         required=False
115 |     )
116 |     parser.add_argument(
117 |         "--video_imgs_dir",
118 |         type=str,
119 |         default='/home/mengting/projects/process_scripts/test_images/target_images1_cropped',
120 |         required=False
121 |     )
122 |     args = parser.parse_args()
123 |     if args.sign == 'image':
124 |         crop_one_image(args)
125 |     elif args.sign == 'video':
126 |         crop_a_directory(args)
127 |     else:
128 |         raise ValueError('sign is invalid')
--------------------------------------------------------------------------------
/data_utils/__pycache__/datasets_faceswap.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/data_utils/__pycache__/datasets_faceswap.cpython-310.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/transfer_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/data_utils/__pycache__/transfer_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/data_utils/datasets_faceswap.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import cv2
4 | from PIL import ImageFile
5 | 
6 | 
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 | 
9 | mean_face_lm5p_256 = np.array([
10 |     [(30.2946+8)*2+16, 51.6963*2], # left eye pupil
11 |     [(65.5318+8)*2+16, 51.5014*2], # right eye pupil
12 |     [(48.0252+8)*2+16, 71.7366*2], # nose tip
13 |     [(33.5493+8)*2+16, 92.3655*2], # left mouth corner
14 |     [(62.7299+8)*2+16, 92.2041*2], # right mouth
corner 15 | ], dtype=np.float32) 16 | 17 | 18 | 19 | mean_box_lm4p_512 = np.array([ 20 | [80, 80], 21 | [80, 432], 22 | [432, 432], 23 | [432, 80], 24 | ], dtype=np.float32) 25 | 26 | 27 | 28 | def get_box_lm4p(pts): 29 | x1 = np.min(pts[:,0]) 30 | x2 = np.max(pts[:,0]) 31 | y1 = np.min(pts[:,1]) 32 | y2 = np.max(pts[:,1]) 33 | 34 | x_center = (x1+x2)*0.5 35 | y_center = (y1+y2)*0.5 36 | box_size = max(x2-x1, y2-y1) 37 | 38 | x1 = x_center-0.5*box_size 39 | x2 = x_center+0.5*box_size 40 | y1 = y_center-0.5*box_size 41 | y2 = y_center+0.5*box_size 42 | 43 | return np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32) 44 | 45 | 46 | def get_affine_transform(target_face_lm5p, mean_lm5p): 47 | mat_warp = np.zeros((2,3)) 48 | A = np.zeros((4,4)) 49 | B = np.zeros((4)) 50 | for i in range(5): 51 | #sa[0][0] += a[i].x*a[i].x + a[i].y*a[i].y; 52 | A[0][0] += target_face_lm5p[i][0] * target_face_lm5p[i][0] + target_face_lm5p[i][1] * target_face_lm5p[i][1] 53 | #sa[0][2] += a[i].x; 54 | A[0][2] += target_face_lm5p[i][0] 55 | #sa[0][3] += a[i].y; 56 | A[0][3] += target_face_lm5p[i][1] 57 | 58 | #sb[0] += a[i].x*b[i].x + a[i].y*b[i].y; 59 | B[0] += target_face_lm5p[i][0] * mean_lm5p[i][0] + target_face_lm5p[i][1] * mean_lm5p[i][1] 60 | #sb[1] += a[i].x*b[i].y - a[i].y*b[i].x; 61 | B[1] += target_face_lm5p[i][0] * mean_lm5p[i][1] - target_face_lm5p[i][1] * mean_lm5p[i][0] 62 | #sb[2] += b[i].x; 63 | B[2] += mean_lm5p[i][0] 64 | #sb[3] += b[i].y; 65 | B[3] += mean_lm5p[i][1] 66 | 67 | #sa[1][1] = sa[0][0]; 68 | A[1][1] = A[0][0] 69 | #sa[2][1] = sa[1][2] = -sa[0][3]; 70 | A[2][1] = A[1][2] = -A[0][3] 71 | #sa[3][1] = sa[1][3] = sa[2][0] = sa[0][2]; 72 | A[3][1] = A[1][3] = A[2][0] = A[0][2] 73 | #sa[2][2] = sa[3][3] = count; 74 | A[2][2] = A[3][3] = 5 75 | #sa[3][0] = sa[0][3]; 76 | A[3][0] = A[0][3] 77 | 78 | _, mat23 = cv2.solve(A, B, flags=cv2.DECOMP_SVD) 79 | mat_warp[0][0] = mat23[0] 80 | mat_warp[1][1] = mat23[0] 81 | mat_warp[0][1] = -mat23[1] 82 | mat_warp[1][0] = mat23[1] 83 | mat_warp[0][2] = mat23[2] 84 | mat_warp[1][2] = mat23[3] 85 | 86 | return mat_warp 87 | 88 | 89 | 90 | 91 | def transformation_from_points(points1, points2): 92 | points1 = np.float64(np.matrix([[point[0], point[1]] for point in points1])) 93 | points2 = np.float64(np.matrix([[point[0], point[1]] for point in points2])) 94 | 95 | points1 = points1.astype(np.float64) 96 | points2 = points2.astype(np.float64) 97 | c1 = np.mean(points1, axis=0) 98 | c2 = np.mean(points2, axis=0) 99 | points1 -= c1 100 | points2 -= c2 101 | s1 = np.std(points1) 102 | s2 = np.std(points2) 103 | points1 /= s1 104 | points2 /= s2 105 | #points2 = np.array(points2) 106 | #write_pts('pt2.txt', points2) 107 | U, S, Vt = np.linalg.svd(points1.T * points2) 108 | R = (U * Vt).T 109 | return np.array(np.vstack([np.hstack(((s2 / s1) * R,c2.T - (s2 / s1) * R * c1.T)),np.matrix([0., 0., 1.])])[:2]) 110 | 111 | -------------------------------------------------------------------------------- /data_utils/transfer_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import cv2 8 | import scipy 9 | from skimage.io import imread, imsave 10 | from skimage.transform import estimate_transform, warp, resize, rescale 11 | from glob import glob 12 | import scipy.io 13 | from decalib.datasets import datasets 14 | from torchvision.utils import 
save_image 15 | from decalib.datasets import detectors 16 | import shutil 17 | 18 | 19 | face_detector = detectors.FAN() 20 | scale = 1.3 21 | resolution_inp = 224 22 | 23 | 24 | 25 | def bbox2point(left, right, top, bottom, type='bbox'): 26 | 27 | if type =='kpt68': 28 | old_size = (right - left + bottom - top) / 2 * 1.1 29 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 30 | elif type =='bbox': 31 | old_size = (right - left + bottom - top ) /2 32 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size *0.12]) 33 | else: 34 | raise NotImplementedError 35 | return old_size, center 36 | 37 | 38 | 39 | 40 | 41 | def get_image_dict(img_path, size, iscrop): 42 | img_name = img_path.split('/')[-1] 43 | im = imread(img_path) 44 | if size is not None: # size = 256 45 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8) 46 | # (256, 256, 3) 47 | image = np.array(im) 48 | if len(image.shape) == 2: 49 | image = image[:, :, None].repeat(1, 1, 3) 50 | if len(image.shape) == 3 and image.shape[2] > 3: 51 | image = image[:, :, :3] 52 | 53 | h, w, _ = image.shape 54 | if iscrop: # true 55 | # provide kpt as txt file, or mat file (for AFLW2000) 56 | kpt_matpath = os.path.splitext(img_path)[0] + '.mat' 57 | kpt_txtpath = os.path.splitext(img_path)[0] + '.txt' 58 | if os.path.exists(kpt_matpath): 59 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T 60 | left = np.min(kpt[:, 0]) 61 | right = np.max(kpt[:, 0]) 62 | top = np.min(kpt[:, 1]) 63 | bottom = np.max(kpt[:, 1]) 64 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68') 65 | elif os.path.exists(kpt_txtpath): 66 | kpt = np.loadtxt(kpt_txtpath) 67 | left = np.min(kpt[:, 0]) 68 | right = np.max(kpt[:, 0]) 69 | top = np.min(kpt[:, 1]) 70 | bottom = np.max(kpt[:, 1]) 71 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68') 72 | else: 73 | bbox, bbox_type = face_detector.run(image) 74 | if len(bbox) < 4: 75 | print('no face detected! run original image') 76 | left = 0 77 | right = h - 1 78 | top = 0 79 | bottom = w - 1 80 | else: 81 | left = bbox[0] 82 | right = bbox[2] 83 | top = bbox[1] 84 | bottom = bbox[3] 85 | old_size, center = bbox2point(left, right, top, bottom, type=bbox_type) 86 | size = int(old_size * scale) 87 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2], 88 | [center[0] + size / 2, center[1] - size / 2]]) 89 | else: 90 | src_pts = np.array([[0, 0], [0, h - 1], [w - 1, 0]]) 91 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]]) 92 | DST_PTS = np.array([[0, 0], [0, resolution_inp - 1], [resolution_inp - 1, 0]]) 93 | tform = estimate_transform('similarity', src_pts, DST_PTS) 94 | 95 | image = image / 255. 
96 | 97 | dst_image = warp(image, tform.inverse, output_shape=(resolution_inp, resolution_inp)) 98 | dst_image = dst_image.transpose(2, 0, 1) 99 | return {'image': torch.tensor(dst_image).float(), 100 | 'imagename': img_name, 101 | 'tform': torch.tensor(tform.params).float(), 102 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(), 103 | } 104 | 105 | 106 | 107 | def check_face(img_path, size): 108 | im = imread(img_path) 109 | if size is not None: # size = 256 110 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8) 111 | # (256, 256, 3) 112 | image = np.array(im) 113 | bbox, bbox_type = face_detector.run(image) 114 | if len(bbox) < 4: 115 | return False 116 | return True 117 | 118 | -------------------------------------------------------------------------------- /decalib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__init__.py -------------------------------------------------------------------------------- /decalib/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/__pycache__/deca_with_smirk.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__pycache__/deca_with_smirk.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/datasets/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/datasets/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/datasets/__pycache__/detectors.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/datasets/__pycache__/detectors.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/datasets/aflw2000.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | import scipy.io 12 | 13 | class AFLW2000(Dataset): 14 | def __init__(self, testpath='/ps/scratch/yfeng/Data/AFLW2000/GT', crop_size=224): 15 | ''' 16 | data class for loading AFLW2000 dataset 17 | make sure each image has corresponding mat file, which provides cropping infromation 18 | ''' 19 | if os.path.isdir(testpath): 20 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') 21 | elif isinstance(testpath, list): 22 | self.imagepath_list = testpath 23 | elif 
os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png']): 24 | self.imagepath_list = [testpath] 25 | else: 26 | print('please check the input path') 27 | exit() 28 | print('total {} images'.format(len(self.imagepath_list))) 29 | self.imagepath_list = sorted(self.imagepath_list) 30 | self.crop_size = crop_size 31 | self.scale = 1.6 32 | self.resolution_inp = crop_size 33 | 34 | def __len__(self): 35 | return len(self.imagepath_list) 36 | 37 | def __getitem__(self, index): 38 | imagepath = self.imagepath_list[index] 39 | imagename = imagepath.split('/')[-1].split('.')[0] 40 | image = imread(imagepath)[:,:,:3] 41 | kpt = scipy.io.loadmat(imagepath.replace('jpg', 'mat'))['pt3d_68'].T 42 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 43 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 44 | 45 | h, w, _ = image.shape 46 | old_size = (right - left + bottom - top)/2 47 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 48 | size = int(old_size*self.scale) 49 | 50 | # crop image 51 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 52 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 53 | tform = estimate_transform('similarity', src_pts, DST_PTS) 54 | 55 | image = image/255. 56 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 57 | dst_image = dst_image.transpose(2,0,1) 58 | return {'image': torch.tensor(dst_image).float(), 59 | 'imagename': imagename, 60 | # 'tform': tform, 61 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 62 | } -------------------------------------------------------------------------------- /decalib/datasets/build_datasets.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | from torch.utils.data import Dataset, ConcatDataset 4 | import torchvision.transforms as transforms 5 | import numpy as np 6 | import cv2 7 | import scipy 8 | from skimage.io import imread, imsave 9 | from skimage.transform import estimate_transform, warp, resize, rescale 10 | from glob import glob 11 | 12 | from .vggface import VGGFace2Dataset 13 | from .ethnicity import EthnicityDataset 14 | from .aflw2000 import AFLW2000 15 | from .now import NoWDataset 16 | from .vox import VoxelDataset 17 | 18 | def build_train(config, is_train=True): 19 | data_list = [] 20 | if 'vox2' in config.training_data: 21 | data_list.append(VoxelDataset(dataname='vox2', K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 22 | if 'vggface2' in config.training_data: 23 | data_list.append(VGGFace2Dataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 24 | if 'vggface2hq' in config.training_data: 25 | data_list.append(VGGFace2HQDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 26 | if 'ethnicity' in config.training_data: 27 | data_list.append(EthnicityDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 28 | if 'coco' in config.training_data: 29 | data_list.append(COCODataset(image_size=config.image_size, 
scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale)) 30 | if 'celebahq' in config.training_data: 31 | data_list.append(CelebAHQDataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale)) 32 | dataset = ConcatDataset(data_list) 33 | 34 | return dataset 35 | 36 | def build_val(config, is_train=True): 37 | data_list = [] 38 | if 'vggface2' in config.eval_data: 39 | data_list.append(VGGFace2Dataset(isEval=True, K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle)) 40 | if 'now' in config.eval_data: 41 | data_list.append(NoWDataset()) 42 | if 'aflw2000' in config.eval_data: 43 | data_list.append(AFLW2000()) 44 | dataset = ConcatDataset(data_list) 45 | 46 | return dataset 47 | -------------------------------------------------------------------------------- /decalib/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os, sys 17 | import torch 18 | from torch.utils.data import Dataset, DataLoader 19 | import torchvision.transforms as transforms 20 | import numpy as np 21 | import cv2 22 | import scipy 23 | from skimage.io import imread, imsave 24 | from skimage.transform import estimate_transform, warp, resize, rescale 25 | from glob import glob 26 | import scipy.io 27 | from decalib.datasets import datasets 28 | from torchvision.utils import save_image 29 | 30 | from . 
import detectors 31 | 32 | def video2sequence(video_path, sample_step=10): 33 | videofolder = os.path.splitext(video_path)[0] 34 | os.makedirs(videofolder, exist_ok=True) 35 | video_name = os.path.splitext(os.path.split(video_path)[-1])[0] 36 | vidcap = cv2.VideoCapture(video_path) 37 | success,image = vidcap.read() 38 | count = 0 39 | imagepath_list = [] 40 | while success: 41 | # if count%sample_step == 0: 42 | imagepath = os.path.join(videofolder, f'{video_name}_frame{count:04d}.jpg') 43 | cv2.imwrite(imagepath, image) # save frame as JPEG file 44 | success,image = vidcap.read() 45 | count += 1 46 | imagepath_list.append(imagepath) 47 | print('video frames are stored in {}'.format(videofolder)) 48 | return imagepath_list 49 | 50 | class TestData(Dataset): 51 | # 传递过来的有iscrop=True, size=256, sort=True 52 | def __init__(self, testpath, iscrop=True, crop_size=224, scale=1.25, face_detector='fan', 53 | sample_step=10, size=256, sort=False): 54 | 55 | if isinstance(testpath, list): 56 | self.imagepath_list = testpath 57 | elif os.path.isdir(testpath): 58 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') + glob(testpath + '/*.bmp') 59 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png', 'bmp']): 60 | self.imagepath_list = [testpath] 61 | elif os.path.isfile(testpath) and (testpath[-3:] in ['mp4', 'csv', 'vid', 'ebm']): 62 | self.imagepath_list = video2sequence(testpath, sample_step) 63 | 64 | if sort: 65 | self.imagepath_list = sorted(self.imagepath_list) 66 | self.crop_size = crop_size 67 | self.scale = scale 68 | self.iscrop = iscrop 69 | self.resolution_inp = crop_size 70 | self.size = size 71 | # 使用的是face alignment的关键点检测工具 72 | if face_detector == 'fan': 73 | self.face_detector = detectors.FAN() 74 | else: 75 | print(f'please check the detector: {face_detector}') 76 | exit() 77 | 78 | def __len__(self): 79 | return len(self.imagepath_list) 80 | 81 | def bbox2point(self, left, right, top, bottom, type='bbox'): 82 | 83 | if type=='kpt68': 84 | old_size = (right - left + bottom - top) / 2 * 1.1 85 | # 人脸中心点的位置 86 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 87 | elif type=='bbox': 88 | old_size = (right - left + bottom - top)/2 89 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12]) 90 | else: 91 | raise NotImplementedError 92 | return old_size, center 93 | 94 | def get_image(self, image): 95 | h, w, _ = image.shape 96 | bbox, bbox_type = self.face_detector.run(image) 97 | if len(bbox) < 4: 98 | print('no face detected! run original image') 99 | left = 0; right = h-1; top=0; bottom=w-1 100 | else: 101 | left = bbox[0]; right=bbox[2] 102 | top = bbox[1]; bottom=bbox[3] 103 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 104 | size = int(old_size*self.scale) 105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 106 | 107 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 108 | tform = estimate_transform('similarity', src_pts, DST_PTS) 109 | 110 | image = image / 255. 
111 | 112 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 113 | dst_image = dst_image.transpose(2,0,1) 114 | return {'image': torch.tensor(dst_image).float(), 115 | 'tform': torch.tensor(tform.params).float(), 116 | 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 117 | } 118 | 119 | 120 | def __getitem__(self, index): 121 | 122 | imagepath = self.imagepath_list[index] 123 | imagename = os.path.splitext(os.path.split(imagepath)[-1])[0] 124 | im = imread(imagepath) 125 | 126 | if self.size is not None: # size = 256 127 | im = (resize(im, (self.size, self.size), anti_aliasing=True) * 255.).astype(np.uint8) 128 | 129 | # (256, 256, 3) 130 | image = np.array(im) 131 | 132 | if len(image.shape) == 2: 133 | image = image[:, :, None].repeat(1,1,3) 134 | if len(image.shape) == 3 and image.shape[2] > 3: 135 | image = image[:, :, :3] 136 | 137 | h, w, _ = image.shape 138 | if self.iscrop: # true 139 | # provide kpt as txt file, or mat file (for AFLW2000) 140 | # 检查是否存在landmark的文件,不存在则自己检测 141 | kpt_matpath = os.path.splitext(imagepath)[0]+'.mat' 142 | kpt_txtpath = os.path.splitext(imagepath)[0]+'.txt' 143 | if os.path.exists(kpt_matpath): 144 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T 145 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 146 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 147 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68') 148 | elif os.path.exists(kpt_txtpath): 149 | kpt = np.loadtxt(kpt_txtpath) 150 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 151 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 152 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68') 153 | else: 154 | bbox, bbox_type = self.face_detector.run(image) 155 | if len(bbox) < 4: 156 | print('no face detected! run original image') 157 | left = 0; right = h-1; top=0; bottom=w-1 158 | else: 159 | left = bbox[0]; right=bbox[2] 160 | top = bbox[1]; bottom=bbox[3] 161 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type) 162 | size = int(old_size * self.scale) 163 | src_pts = np.array([[center[0] - size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 164 | else: 165 | src_pts = np.array([[0, 0], [0, h-1], [w-1, 0]]) 166 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]]) 167 | # self.resolution_inp = 224,目标图像大小 168 | DST_PTS = np.array([[0, 0], [0, self.resolution_inp - 1], [self.resolution_inp - 1, 0]]) 169 | # 计算源图像变换到目标图像需要经过怎样的矩阵转换 170 | tform = estimate_transform('similarity', src_pts, DST_PTS) 171 | 172 | image = image / 255. # 0-1区间 173 | 174 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp)) 175 | dst_image = dst_image.transpose(2,0,1) 176 | 177 | return {'image': torch.tensor(dst_image).float(), # 只对面部区域进行了保留 178 | 'imagename': imagename, 179 | 'tform': torch.tensor(tform.params).float(), 180 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(), 181 | } 182 | 183 | 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | testdata_source = datasets.TestData( 189 | source, iscrop=True, size=512, sort=True 190 | ) -------------------------------------------------------------------------------- /decalib/datasets/detectors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch 18 | 19 | class FAN(object): 20 | def __init__(self): 21 | import face_alignment 22 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) 23 | 24 | def run(self, image): 25 | ''' 26 | image: 0-255, uint8, rgb, [h, w, 3] 27 | return: detected box list 28 | ''' 29 | out = self.model.get_landmarks(image) 30 | if out is None: 31 | return [0], 'kpt68' 32 | else: 33 | kpt = out[0].squeeze() 34 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]) 35 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 36 | bbox = [left, top, right, bottom] 37 | return bbox, 'kpt68' 38 | 39 | class MTCNN(object): 40 | def __init__(self, device = 'cpu'): 41 | ''' 42 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb 43 | ''' 44 | from facenet_pytorch import MTCNN as mtcnn 45 | self.device = device 46 | self.model = mtcnn(keep_all=True) 47 | def run(self, input): 48 | ''' 49 | image: 0-255, uint8, rgb, [h, w, 3] 50 | return: detected box 51 | ''' 52 | out = self.model.detect(input[None,...]) 53 | if out[0][0] is None: 54 | return [0] 55 | else: 56 | bbox = out[0][0].squeeze() 57 | return bbox, 'bbox' 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /decalib/datasets/ethnicity.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class EthnicityDataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 14 | ''' 15 | K must be less than 6 16 | ''' 17 | self.K = K 18 | self.image_size = image_size 19 | self.imagefolder = '/ps/scratch/face2d3d/train' 20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/' 21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/' 22 | # hq: 23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_and_race_per_7000_african_asian_2d_train_list_max_normal_100_ring_5_1_serial.npy' 25 | self.data_lines = np.load(datafile).astype('str') 26 | 27 | self.isTemporal = isTemporal 28 | self.scale = scale #[scale_min, scale_max] 29 | self.trans_scale = trans_scale #[dx, dy] 30 | self.isSingle = isSingle 31 
| if isSingle: 32 | self.K = 1 33 | 34 | def __len__(self): 35 | return len(self.data_lines) 36 | 37 | def __getitem__(self, idx): 38 | images_list = []; kpt_list = []; mask_list = [] 39 | for i in range(self.K): 40 | name = self.data_lines[idx, i] 41 | if name[0]=='n': 42 | self.imagefolder = '/ps/scratch/face2d3d/train/' 43 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/' 44 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/' 45 | elif name[0]=='A': 46 | self.imagefolder = '/ps/scratch/face2d3d/race_per_7000/' 47 | self.kptfolder = '/ps/scratch/face2d3d/race_per_7000_annotated_torch7_new/' 48 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/race7000_seg/test_crop_size_400_batch/' 49 | 50 | image_path = os.path.join(self.imagefolder, name + '.jpg') 51 | seg_path = os.path.join(self.segfolder, name + '.npy') 52 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 53 | 54 | image = imread(image_path)/255. 55 | kpt = np.load(kpt_path)[:,:2] 56 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 57 | 58 | ### crop information 59 | tform = self.crop(image, kpt) 60 | ## crop 61 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 62 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 63 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 64 | 65 | # normalized kpt 66 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 67 | 68 | images_list.append(cropped_image.transpose(2,0,1)) 69 | kpt_list.append(cropped_kpt) 70 | mask_list.append(cropped_mask) 71 | 72 | ### 73 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 74 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 75 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 76 | 77 | if self.isSingle: 78 | images_array = images_array.squeeze() 79 | kpt_array = kpt_array.squeeze() 80 | mask_array = mask_array.squeeze() 81 | 82 | data_dict = { 83 | 'image': images_array, 84 | 'landmark': kpt_array, 85 | 'mask': mask_array 86 | } 87 | 88 | return data_dict 89 | 90 | def crop(self, image, kpt): 91 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 92 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 93 | 94 | h, w, _ = image.shape 95 | old_size = (right - left + bottom - top)/2 96 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 97 | # translate center 98 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 99 | center = center + trans_scale*old_size # 0.5 100 | 101 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 102 | size = int(old_size*scale) 103 | 104 | # crop image 105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 106 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 107 | tform = estimate_transform('similarity', src_pts, DST_PTS) 108 | 109 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 110 | # # change kpt accordingly 111 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 112 | return tform 113 | 114 | def load_mask(self, maskpath, h, 
w): 115 | # print(maskpath) 116 | if os.path.isfile(maskpath): 117 | vis_parsing_anno = np.load(maskpath) 118 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 119 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 120 | mask = np.zeros_like(vis_parsing_anno) 121 | # for i in range(1, 16): 122 | mask[vis_parsing_anno>0.5] = 1. 123 | else: 124 | mask = np.ones((h, w)) 125 | return mask 126 | 127 | -------------------------------------------------------------------------------- /decalib/datasets/now.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class NoWDataset(Dataset): 13 | def __init__(self, ring_elements=6, crop_size=224, scale=1.6): 14 | folder = '/ps/scratch/yfeng/other-github/now_evaluation/data/NoW_Dataset' 15 | self.data_path = os.path.join(folder, 'imagepathsvalidation.txt') 16 | with open(self.data_path) as f: 17 | self.data_lines = f.readlines() 18 | 19 | self.imagefolder = os.path.join(folder, 'final_release_version', 'iphone_pictures') 20 | self.bbxfolder = os.path.join(folder, 'final_release_version', 'detected_face') 21 | 22 | # self.data_path = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/test_image_paths_ring_6_elements.npy' 23 | # self.imagepath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/iphone_pictures/' 24 | # self.bbxpath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/detected_face/' 25 | self.crop_size = crop_size 26 | self.scale = scale 27 | 28 | def __len__(self): 29 | return len(self.data_lines) 30 | 31 | def __getitem__(self, index): 32 | imagepath = os.path.join(self.imagefolder, self.data_lines[index].strip()) #+ '.jpg' 33 | bbx_path = os.path.join(self.bbxfolder, self.data_lines[index].strip().replace('.jpg', '.npy')) 34 | bbx_data = np.load(bbx_path, allow_pickle=True, encoding='latin1').item() 35 | # box = np.array([[bbx_data['left'], bbx_data['top']], [bbx_data['right'], bbx_data['bottom']]]).astype('float32') 36 | left = bbx_data['left']; right = bbx_data['right'] 37 | top = bbx_data['top']; bottom = bbx_data['bottom'] 38 | 39 | imagename = imagepath.split('/')[-1].split('.')[0] 40 | image = imread(imagepath)[:,:,:3] 41 | 42 | h, w, _ = image.shape 43 | old_size = (right - left + bottom - top)/2 44 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ]) 45 | size = int(old_size*self.scale) 46 | 47 | # crop image 48 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 49 | DST_PTS = np.array([[0,0], [0,self.crop_size - 1], [self.crop_size - 1, 0]]) 50 | tform = estimate_transform('similarity', src_pts, DST_PTS) 51 | 52 | image = image/255. 
53 | dst_image = warp(image, tform.inverse, output_shape=(self.crop_size, self.crop_size)) 54 | dst_image = dst_image.transpose(2,0,1) 55 | return {'image': torch.tensor(dst_image).float(), 56 | 'imagename': self.data_lines[index].strip().replace('.jpg', ''), 57 | # 'tform': tform, 58 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(), 59 | } -------------------------------------------------------------------------------- /decalib/datasets/vggface.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class VGGFace2Dataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 14 | ''' 15 | K must be less than 6 16 | ''' 17 | self.K = K 18 | self.image_size = image_size 19 | self.imagefolder = '/ps/scratch/face2d3d/train' 20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7' 21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch' 22 | # hq: 23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_train_list_max_normal_100_ring_5_1_serial.npy' 25 | if isEval: 26 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_val_list_max_normal_100_ring_5_1_serial.npy' 27 | self.data_lines = np.load(datafile).astype('str') 28 | 29 | self.isTemporal = isTemporal 30 | self.scale = scale #[scale_min, scale_max] 31 | self.trans_scale = trans_scale #[dx, dy] 32 | self.isSingle = isSingle 33 | if isSingle: 34 | self.K = 1 35 | 36 | def __len__(self): 37 | return len(self.data_lines) 38 | 39 | def __getitem__(self, idx): 40 | images_list = []; kpt_list = []; mask_list = [] 41 | 42 | random_ind = np.random.permutation(5)[:self.K] 43 | for i in random_ind: 44 | name = self.data_lines[idx, i] 45 | image_path = os.path.join(self.imagefolder, name + '.jpg') 46 | seg_path = os.path.join(self.segfolder, name + '.npy') 47 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 48 | 49 | image = imread(image_path)/255. 
50 | kpt = np.load(kpt_path)[:,:2] 51 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 52 | 53 | ### crop information 54 | tform = self.crop(image, kpt) 55 | ## crop 56 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 57 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 58 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 59 | 60 | # normalized kpt 61 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 62 | 63 | images_list.append(cropped_image.transpose(2,0,1)) 64 | kpt_list.append(cropped_kpt) 65 | mask_list.append(cropped_mask) 66 | 67 | ### 68 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 69 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 70 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 71 | 72 | if self.isSingle: 73 | images_array = images_array.squeeze() 74 | kpt_array = kpt_array.squeeze() 75 | mask_array = mask_array.squeeze() 76 | 77 | data_dict = { 78 | 'image': images_array, 79 | 'landmark': kpt_array, 80 | 'mask': mask_array 81 | } 82 | 83 | return data_dict 84 | 85 | def crop(self, image, kpt): 86 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 87 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 88 | 89 | h, w, _ = image.shape 90 | old_size = (right - left + bottom - top)/2 91 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 92 | # translate center 93 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 94 | center = center + trans_scale*old_size # 0.5 95 | 96 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 97 | size = int(old_size*scale) 98 | 99 | # crop image 100 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 101 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 102 | tform = estimate_transform('similarity', src_pts, DST_PTS) 103 | 104 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 105 | # # change kpt accordingly 106 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 107 | return tform 108 | 109 | def load_mask(self, maskpath, h, w): 110 | # print(maskpath) 111 | if os.path.isfile(maskpath): 112 | vis_parsing_anno = np.load(maskpath) 113 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 114 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 115 | mask = np.zeros_like(vis_parsing_anno) 116 | # for i in range(1, 16): 117 | mask[vis_parsing_anno>0.5] = 1. 
118 | else: 119 | mask = np.ones((h, w)) 120 | return mask 121 | 122 | 123 | 124 | class VGGFace2HQDataset(Dataset): 125 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False): 126 | ''' 127 | K must be less than 6 128 | ''' 129 | self.K = K 130 | self.image_size = image_size 131 | self.imagefolder = '/ps/scratch/face2d3d/train' 132 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7' 133 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch' 134 | # hq: 135 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 136 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy' 137 | self.data_lines = np.load(datafile).astype('str') 138 | 139 | self.isTemporal = isTemporal 140 | self.scale = scale #[scale_min, scale_max] 141 | self.trans_scale = trans_scale #[dx, dy] 142 | self.isSingle = isSingle 143 | if isSingle: 144 | self.K = 1 145 | 146 | def __len__(self): 147 | return len(self.data_lines) 148 | 149 | def __getitem__(self, idx): 150 | images_list = []; kpt_list = []; mask_list = [] 151 | 152 | for i in range(self.K): 153 | name = self.data_lines[idx, i] 154 | image_path = os.path.join(self.imagefolder, name + '.jpg') 155 | seg_path = os.path.join(self.segfolder, name + '.npy') 156 | kpt_path = os.path.join(self.kptfolder, name + '.npy') 157 | 158 | image = imread(image_path)/255. 159 | kpt = np.load(kpt_path)[:,:2] 160 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1]) 161 | 162 | ### crop information 163 | tform = self.crop(image, kpt) 164 | ## crop 165 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 166 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size)) 167 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 168 | 169 | # normalized kpt 170 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 171 | 172 | images_list.append(cropped_image.transpose(2,0,1)) 173 | kpt_list.append(cropped_kpt) 174 | mask_list.append(cropped_mask) 175 | 176 | ### 177 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 178 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 179 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3 180 | 181 | if self.isSingle: 182 | images_array = images_array.squeeze() 183 | kpt_array = kpt_array.squeeze() 184 | mask_array = mask_array.squeeze() 185 | 186 | data_dict = { 187 | 'image': images_array, 188 | 'landmark': kpt_array, 189 | 'mask': mask_array 190 | } 191 | 192 | return data_dict 193 | 194 | def crop(self, image, kpt): 195 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]); 196 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1]) 197 | 198 | h, w, _ = image.shape 199 | old_size = (right - left + bottom - top)/2 200 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1]) 201 | # translate center 202 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale 203 | center = center + trans_scale*old_size # 0.5 204 | 205 
| scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 206 | size = int(old_size*scale) 207 | 208 | # crop image 209 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 210 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]]) 211 | tform = estimate_transform('similarity', src_pts, DST_PTS) 212 | 213 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size)) 214 | # # change kpt accordingly 215 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params) 216 | return tform 217 | 218 | def load_mask(self, maskpath, h, w): 219 | # print(maskpath) 220 | if os.path.isfile(maskpath): 221 | vis_parsing_anno = np.load(maskpath) 222 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 223 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 224 | mask = np.zeros_like(vis_parsing_anno) 225 | # for i in range(1, 16): 226 | mask[vis_parsing_anno>0.5] = 1. 227 | else: 228 | mask = np.ones((h, w)) 229 | return mask -------------------------------------------------------------------------------- /decalib/datasets/vox.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | import scipy 7 | from skimage.io import imread, imsave 8 | from skimage.transform import estimate_transform, warp, resize, rescale 9 | from glob import glob 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 11 | 12 | class VoxelDataset(Dataset): 13 | def __init__(self, K, image_size, scale, trans_scale = 0, dataname='vox2', n_train=100000, isTemporal=False, isEval=False, isSingle=False): 14 | self.K = K 15 | self.image_size = image_size 16 | if dataname == 'vox1': 17 | self.kpt_suffix = '.txt' 18 | self.imagefolder = '/ps/project/face2d3d/VoxCeleb/vox1/dev/images_cropped' 19 | self.kptfolder = '/ps/scratch/yfeng/Data/VoxCeleb/vox1/landmark_2d' 20 | 21 | self.face_dict = {} 22 | for person_id in sorted(os.listdir(self.kptfolder)): 23 | for video_id in os.listdir(os.path.join(self.kptfolder, person_id)): 24 | for face_id in os.listdir(os.path.join(self.kptfolder, person_id, video_id)): 25 | if 'txt' in face_id: 26 | continue 27 | key = person_id + '/' + video_id + '/' + face_id 28 | # if key not in self.face_dict.keys(): 29 | # self.face_dict[key] = [] 30 | name_list = os.listdir(os.path.join(self.kptfolder, person_id, video_id, face_id)) 31 | name_list = [name.split['.'][0] for name in name_list] 32 | if len(name_list)0.5] = 1. 162 | else: 163 | mask = np.ones((h, w)) 164 | return mask 165 | -------------------------------------------------------------------------------- /decalib/models/FLAME.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). 
acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import pickle 20 | import torch.nn.functional as F 21 | 22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler 23 | 24 | def to_tensor(array, dtype=torch.float32): 25 | if 'torch.tensor' not in str(type(array)): 26 | return torch.tensor(array, dtype=dtype) 27 | def to_np(array, dtype=np.float32): 28 | if 'scipy.sparse' in str(type(array)): 29 | array = array.todense() 30 | return np.array(array, dtype=dtype) 31 | 32 | class Struct(object): 33 | def __init__(self, **kwargs): 34 | for key, val in kwargs.items(): 35 | setattr(self, key, val) 36 | 37 | class FLAME(nn.Module): 38 | """ 39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py 40 | Given flame parameters this class generates a differentiable FLAME function 41 | which outputs the a mesh and 2D/3D facial landmarks 42 | """ 43 | def __init__(self, config): 44 | super(FLAME, self).__init__() 45 | print("creating the FLAME Decoder") 46 | with open(config.flame_model_path, 'rb') as f: 47 | ss = pickle.load(f, encoding='latin1') 48 | flame_model = Struct(**ss) 49 | 50 | self.dtype = torch.float32 51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long)) 52 | # The vertices of the template model 53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype)) 54 | # The shape components and expression 55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype) 56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2) 57 | self.register_buffer('shapedirs', shapedirs) 58 | # The pose components 59 | num_pose_basis = flame_model.posedirs.shape[-1] 60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T 61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 62 | # 63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)) 64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1 65 | self.register_buffer('parents', parents) 66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype)) 67 | 68 | # Fixing Eyeball and neck rotation 69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) 70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, 71 | requires_grad=False)) 72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) 73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, 74 | requires_grad=False)) 75 | 76 | # Static and Dynamic Landmark embeddings for FLAME 77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1') 78 | lmk_embeddings = lmk_embeddings[()] 79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long()) 80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype)) 81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long()) 82 | 
self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype)) 83 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long()) 84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype)) 85 | 86 | neck_kin_chain = []; NECK_IDX=1 87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long) 88 | while curr_idx != -1: 89 | neck_kin_chain.append(curr_idx) 90 | curr_idx = self.parents[curr_idx] 91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain)) 92 | 93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx, 94 | dynamic_lmk_b_coords, 95 | neck_kin_chain, dtype=torch.float32): 96 | """ 97 | Selects the face contour depending on the reletive position of the head 98 | Input: 99 | vertices: N X num_of_vertices X 3 100 | pose: N X full pose 101 | dynamic_lmk_faces_idx: The list of contour face indexes 102 | dynamic_lmk_b_coords: The list of contour barycentric weights 103 | neck_kin_chain: The tree to consider for the relative rotation 104 | dtype: Data type 105 | return: 106 | The contour face indexes and the corresponding barycentric weights 107 | """ 108 | 109 | batch_size = pose.shape[0] 110 | 111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 112 | neck_kin_chain) 113 | rot_mats = batch_rodrigues( 114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 115 | 116 | rel_rot_mat = torch.eye(3, device=pose.device, 117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1) 118 | for idx in range(len(neck_kin_chain)): 119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 120 | 121 | y_rot_angle = torch.round( 122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 123 | max=39)).to(dtype=torch.long) 124 | 125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 128 | y_rot_angle = (neg_mask * neg_vals + 129 | (1 - neg_mask) * y_rot_angle) 130 | 131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 132 | 0, y_rot_angle) 133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 134 | 0, y_rot_angle) 135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 136 | 137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords): 138 | """ 139 | Calculates landmarks by barycentric interpolation 140 | Input: 141 | vertices: torch.tensor NxVx3, dtype = torch.float32 142 | The tensor of input vertices 143 | faces: torch.tensor (N*F)x3, dtype = torch.long 144 | The faces of the mesh 145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long 146 | The tensor with the indices of the faces used to calculate the 147 | landmarks. 
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32 149 | The tensor of barycentric coordinates that are used to interpolate 150 | the landmarks 151 | 152 | Returns: 153 | landmarks: torch.tensor NxLx3, dtype = torch.float32 154 | The coordinates of the landmarks for each mesh in the batch 155 | """ 156 | # Extract the indices of the vertices for each face 157 | # NxLx3 158 | batch_size, num_verts = vertices.shape[:2] 159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1) 161 | 162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( 163 | device=vertices.device) * num_verts 164 | 165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces] 166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 167 | return landmarks 168 | 169 | def seletec_3d68(self, vertices): 170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1), 172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 173 | return landmarks3d 174 | 175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None): 176 | """ 177 | Input: 178 | shape_params: N X number of shape parameters 179 | expression_params: N X number of expression parameters 180 | pose_params: N X number of pose parameters (6) 181 | return: 182 | vertices: N X V X 3 183 | landmarks: N X number of landmarks X 3 184 | """ 185 | batch_size = shape_params.shape[0] 186 | if pose_params is None: 187 | pose_params = self.eye_pose.expand(batch_size, -1) 188 | if eye_pose_params is None: 189 | eye_pose_params = self.eye_pose.expand(batch_size, -1) 190 | betas = torch.cat([shape_params, expression_params], dim=1) 191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1) 192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 193 | 194 | vertices, _ = lbs(betas, full_pose, template_vertices, 195 | self.shapedirs, self.posedirs, 196 | self.J_regressor, self.parents, 197 | self.lbs_weights, dtype=self.dtype) 198 | 199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 201 | 202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 203 | full_pose, self.dynamic_lmk_faces_idx, 204 | self.dynamic_lmk_bary_coords, 205 | self.neck_kin_chain, dtype=self.dtype) 206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 208 | 209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 210 | lmk_faces_idx, 211 | lmk_bary_coords) 212 | bz = vertices.shape[0] 213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 214 | self.full_lmk_faces_idx.repeat(bz, 1), 215 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 216 | return vertices, landmarks2d, landmarks3d 217 | 218 | class FLAMETex(nn.Module): 219 | """ 220 | FLAME texture: 221 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64 222 | FLAME texture converted from BFM: 223 | https://github.com/TimoBolkart/BFM_to_FLAME 224 | """ 225 | def __init__(self, config): 226 | super(FLAMETex, self).__init__() 227 | if config.tex_type == 'BFM': 228 | mu_key = 'MU' 229 | pc_key = 'PC'
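The `FLAME.forward` above maps batched FLAME coefficients to a mesh plus 2D/3D landmarks. The following is an illustrative sketch only (editor's example, not part of FLAME.py), assuming the default sizes from `decalib/utils/config.py` (`n_shape=100`, `n_exp=50`, `n_pose=6`) and the package layout shown in this repository:

```python
import torch
from decalib.models.FLAME import FLAME
from decalib.utils.config import get_cfg_defaults

cfg = get_cfg_defaults()
flame = FLAME(cfg.model).eval()   # loads generic_model.pkl and landmark_embedding.npy from decalib/data

B = 1
shape = torch.zeros(B, cfg.model.n_shape)   # identity coefficients
exp = torch.zeros(B, cfg.model.n_exp)       # expression coefficients
pose = torch.zeros(B, cfg.model.n_pose)     # global rotation (3) + jaw pose (3), axis-angle

with torch.no_grad():
    verts, lmk2d, lmk3d = flame(shape_params=shape, expression_params=exp, pose_params=pose)
# verts: [B, 5023, 3] mesh vertices; lmk2d and lmk3d: [B, 68, 3] landmark positions
```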
230 | n_pc = 199 231 | tex_path = config.tex_path 232 | tex_space = np.load(tex_path) 233 | texture_mean = tex_space[mu_key].reshape(1, -1) 234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc) 235 | 236 | elif config.tex_type == 'FLAME': 237 | mu_key = 'mean' 238 | pc_key = 'tex_dir' 239 | n_pc = 200 240 | tex_path = config.tex_path # config.flame_tex_path 241 | tex_space = np.load(tex_path) 242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255. 243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255. 244 | else: 245 | print('texture type ', config.tex_type, 'not exist!') 246 | raise NotImplementedError 247 | 248 | n_tex = config.n_tex 249 | num_components = texture_basis.shape[1] 250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...] 251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...] 252 | self.register_buffer('texture_mean', texture_mean) 253 | self.register_buffer('texture_basis', texture_basis) 254 | 255 | def forward(self, texcode): 256 | ''' 257 | texcode: [batchsize, n_tex] 258 | texture: [bz, 3, 256, 256], range: 0-1 259 | ''' 260 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1) 261 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2) 262 | texture = F.interpolate(texture, [256, 256]) 263 | texture = texture[:,[2,1,0], :,:] 264 | return texture 265 | -------------------------------------------------------------------------------- /decalib/models/__pycache__/FLAME.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/FLAME.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/models/__pycache__/decoders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/decoders.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/models/__pycache__/encoders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/encoders.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/models/__pycache__/lbs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/lbs.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/models/decoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'): 21 | super(Generator, self).__init__() 22 | self.out_scale = out_scale 23 | 24 | self.init_size = 32 // 4 # Initial size before upsampling 25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) 26 | self.conv_blocks = nn.Sequential( 27 | nn.BatchNorm2d(128), 28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16 29 | nn.Conv2d(128, 128, 3, stride=1, padding=1), 30 | nn.BatchNorm2d(128, 0.8), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32 33 | nn.Conv2d(128, 64, 3, stride=1, padding=1), 34 | nn.BatchNorm2d(64, 0.8), 35 | nn.LeakyReLU(0.2, inplace=True), 36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64 37 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 38 | nn.BatchNorm2d(64, 0.8), 39 | nn.LeakyReLU(0.2, inplace=True), 40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128 41 | nn.Conv2d(64, 32, 3, stride=1, padding=1), 42 | nn.BatchNorm2d(32, 0.8), 43 | nn.LeakyReLU(0.2, inplace=True), 44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256 45 | nn.Conv2d(32, 16, 3, stride=1, padding=1), 46 | nn.BatchNorm2d(16, 0.8), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1), 49 | nn.Tanh(), 50 | ) 51 | 52 | def forward(self, noise): 53 | out = self.l1(noise) 54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size) 55 | img = self.conv_blocks(out) 56 | return img*self.out_scale -------------------------------------------------------------------------------- /decalib/models/encoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import numpy as np 17 | import torch.nn as nn 18 | import torch 19 | import torch.nn.functional as F 20 | from . 
import resnet 21 | 22 | class ResnetEncoder(nn.Module): 23 | def __init__(self, outsize, last_op=None): 24 | super(ResnetEncoder, self).__init__() 25 | feature_size = 2048 26 | self.encoder = resnet.load_ResNet50Model() #out: 2048 27 | ### regressor 28 | self.layers = nn.Sequential( 29 | nn.Linear(feature_size, 1024), 30 | nn.ReLU(), 31 | nn.Linear(1024, outsize) 32 | ) 33 | self.last_op = last_op 34 | 35 | def forward(self, inputs): 36 | features = self.encoder(inputs) 37 | parameters = self.layers(features) 38 | if self.last_op: 39 | parameters = self.last_op(parameters) 40 | return parameters 41 | -------------------------------------------------------------------------------- /decalib/models/frnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | # from pro_gan_pytorch.PRO_GAN import ProGAN, Generator, Discriminator 5 | import torch.nn.functional as F 6 | import cv2 7 | from torch.autograd import Variable 8 | import math 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | 87 | def __init__(self, block, layers, num_classes=1000, include_top=True): 88 | self.inplanes = 64 89 | super(ResNet, self).__init__() 90 | self.include_top = include_top 91 | 92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, 
ceil_mode=True) 96 | 97 | self.layer1 = self._make_layer(block, 64, layers[0]) 98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 101 | self.avgpool = nn.AvgPool2d(7, stride=1) 102 | self.fc = nn.Linear(512 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. / n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | nn.BatchNorm2d(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample)) 123 | self.inplanes = planes * block.expansion 124 | for i in range(1, blocks): 125 | layers.append(block(self.inplanes, planes)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | x = self.avgpool(x) 141 | 142 | if not self.include_top: 143 | return x 144 | 145 | x = x.view(x.size(0), -1) 146 | x = self.fc(x) 147 | return x 148 | 149 | def resnet50(**kwargs): 150 | """Constructs a ResNet-50 model. 151 | """ 152 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 153 | return model 154 | 155 | import pickle 156 | def load_state_dict(model, fname): 157 | """ 158 | Set parameters converted from Caffe models authors of VGGFace2 provide. 159 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/. 160 | Arguments: 161 | model: model 162 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle. 163 | """ 164 | with open(fname, 'rb') as f: 165 | weights = pickle.load(f, encoding='latin1') 166 | 167 | own_state = model.state_dict() 168 | for name, param in weights.items(): 169 | if name in own_state: 170 | try: 171 | own_state[name].copy_(torch.from_numpy(param)) 172 | except Exception: 173 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\ 174 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size())) 175 | else: 176 | raise KeyError('unexpected key "{}" in state_dict'.format(name)) 177 | 178 | -------------------------------------------------------------------------------- /decalib/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 
5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. / n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def 
conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | # print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | def load_ResNet50Model(): 174 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = False).state_dict()) 176 | return model 177 | 178 | def load_ResNet101Model(): 179 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 181 | return model 182 | 183 | def load_ResNet152Model(): 184 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 186 | return model 187 | 188 | # model.load_state_dict(checkpoint['model_state_dict']) 189 | 190 | 191 | ######## Unet 192 | 193 | class DoubleConv(nn.Module): 194 | """(convolution => [BN] => ReLU) * 2""" 195 | 196 | def __init__(self, in_channels, out_channels): 197 | super().__init__() 198 | self.double_conv = nn.Sequential( 199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 200 | nn.BatchNorm2d(out_channels), 201 | nn.ReLU(inplace=True), 202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 203 | nn.BatchNorm2d(out_channels), 204 | nn.ReLU(inplace=True) 205 | ) 206 | 207 | def forward(self, x): 208 | return self.double_conv(x) 209 | 210 | 211 | class Down(nn.Module): 212 | """Downscaling with maxpool then double conv""" 213 | 214 | def __init__(self, in_channels, out_channels): 215 | super().__init__() 216 | self.maxpool_conv = nn.Sequential( 217 | nn.MaxPool2d(2), 218 | DoubleConv(in_channels, out_channels) 219 | ) 220 | 221 | def forward(self, x): 222 | return self.maxpool_conv(x) 223 | 224 | 225 | class Up(nn.Module): 226 | """Upscaling then double conv""" 227 | 228 | def __init__(self, in_channels, out_channels, bilinear=True): 229 | super().__init__() 230 | 231 | # if bilinear, use the normal convolutions to reduce the number of channels 232 | if bilinear: 233 | self.up = 
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 234 | else: 235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 236 | 237 | self.conv = DoubleConv(in_channels, out_channels) 238 | 239 | def forward(self, x1, x2): 240 | x1 = self.up(x1) 241 | # input is CHW 242 | diffY = x2.size()[2] - x1.size()[2] 243 | diffX = x2.size()[3] - x1.size()[3] 244 | 245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 246 | diffY // 2, diffY - diffY // 2]) 247 | # if you have padding issues, see 248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 250 | x = torch.cat([x2, x1], dim=1) 251 | return self.conv(x) 252 | 253 | 254 | class OutConv(nn.Module): 255 | def __init__(self, in_channels, out_channels): 256 | super(OutConv, self).__init__() 257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 258 | 259 | def forward(self, x): 260 | return self.conv(x) -------------------------------------------------------------------------------- /decalib/smirk/__pycache__/mediapipe_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/__pycache__/mediapipe_utils.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/smirk/__pycache__/smirk_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/__pycache__/smirk_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/smirk/face_landmarker.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/face_landmarker.task -------------------------------------------------------------------------------- /decalib/smirk/mediapipe_utils.py: -------------------------------------------------------------------------------- 1 | import mediapipe as mp 2 | from mediapipe.tasks import python 3 | from mediapipe.tasks.python import vision 4 | import cv2 5 | import numpy as np 6 | import os 7 | base_options = python.BaseOptions(model_asset_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'face_landmarker.task'))  # resolve the bundled model relative to this file rather than a machine-specific absolute path 8 | options = vision.FaceLandmarkerOptions(base_options=base_options, 9 | output_face_blendshapes=True, 10 | output_facial_transformation_matrixes=True, 11 | num_faces=1, 12 | min_face_detection_confidence=0.1, 13 | min_face_presence_confidence=0.1 14 | ) 15 | detector = vision.FaceLandmarker.create_from_options(options) 16 | 17 | 18 | def run_mediapipe(image): 19 | # print(image.shape) 20 | image_numpy = cv2.cvtColor(image,cv2.COLOR_BGR2RGB) 21 | 22 | # STEP 3: Load the input image. 23 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_numpy) 24 | 25 | 26 | # STEP 4: Detect face landmarks from the input image.
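# Note (editor's annotation, not in the original file): detector.detect() returns a
# FaceLandmarkerResult whose face_landmarks entries are normalized to [0, 1]; the loop
# below rescales x and y by image.width and image.height into a (478, 3) pixel-space array.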
27 | detection_result = detector.detect(image) 28 | 29 | if len (detection_result.face_landmarks) == 0: 30 | print('No face detected') 31 | return None 32 | 33 | face_landmarks = detection_result.face_landmarks[0] 34 | 35 | face_landmarks_numpy = np.zeros((478, 3)) 36 | 37 | for i, landmark in enumerate(face_landmarks): 38 | face_landmarks_numpy[i] = [landmark.x*image.width, landmark.y*image.height, landmark.z] 39 | 40 | return face_landmarks_numpy 41 | -------------------------------------------------------------------------------- /decalib/smirk/smirk_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import timm 5 | 6 | 7 | def create_backbone(backbone_name, pretrained=True): 8 | backbone = timm.create_model(backbone_name, 9 | pretrained=pretrained, 10 | features_only=True) 11 | feature_dim = backbone.feature_info[-1]['num_chs'] 12 | return backbone, feature_dim 13 | 14 | class PoseEncoder(nn.Module): 15 | def __init__(self) -> None: 16 | super().__init__() 17 | 18 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100') 19 | 20 | self.pose_cam_layers = nn.Sequential( 21 | nn.Linear(feature_dim, 6) 22 | ) 23 | 24 | self.init_weights() 25 | 26 | def init_weights(self): 27 | self.pose_cam_layers[-1].weight.data *= 0.001 28 | self.pose_cam_layers[-1].bias.data *= 0.001 29 | 30 | self.pose_cam_layers[-1].weight.data[3] = 0 31 | self.pose_cam_layers[-1].bias.data[3] = 7 32 | 33 | 34 | def forward(self, img): 35 | features = self.encoder(img)[-1] 36 | 37 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 38 | 39 | outputs = {} 40 | 41 | pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1) 42 | outputs['pose_params'] = pose_cam[...,:3] 43 | outputs['cam'] = pose_cam[...,3:] 44 | 45 | return outputs 46 | 47 | 48 | class ShapeEncoder(nn.Module): 49 | def __init__(self, n_shape=300) -> None: 50 | super().__init__() 51 | 52 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100') 53 | 54 | self.shape_layers = nn.Sequential( 55 | nn.Linear(feature_dim, n_shape) 56 | ) 57 | 58 | self.init_weights() 59 | 60 | 61 | def init_weights(self): 62 | self.shape_layers[-1].weight.data *= 0 63 | self.shape_layers[-1].bias.data *= 0 64 | 65 | 66 | def forward(self, img): 67 | features = self.encoder(img)[-1] 68 | 69 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 70 | 71 | parameters = self.shape_layers(features).reshape(img.size(0), -1) 72 | 73 | return {'shape_params': parameters} 74 | 75 | 76 | class ExpressionEncoder(nn.Module): 77 | def __init__(self, n_exp=50) -> None: 78 | super().__init__() 79 | 80 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100') 81 | 82 | self.expression_layers = nn.Sequential( 83 | nn.Linear(feature_dim, n_exp+2+3) # num expressions + jaw + eyelid 84 | ) 85 | 86 | self.n_exp = n_exp 87 | self.init_weights() 88 | 89 | 90 | def init_weights(self): 91 | self.expression_layers[-1].weight.data *= 0.1 92 | self.expression_layers[-1].bias.data *= 0.1 93 | 94 | 95 | def forward(self, img): 96 | features = self.encoder(img)[-1] 97 | 98 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 99 | 100 | 101 | parameters = self.expression_layers(features).reshape(img.size(0), -1) 102 | 103 | outputs = {} 104 | 105 | outputs['expression_params'] = parameters[...,:self.n_exp] 106 | outputs['eyelid_params'] 
= torch.clamp(parameters[...,self.n_exp:self.n_exp+2], 0, 1) 107 | outputs['jaw_params'] = torch.cat([F.relu(parameters[...,self.n_exp+2].unsqueeze(-1)), 108 | torch.clamp(parameters[...,self.n_exp+3:self.n_exp+5], -.2, .2)], dim=-1) 109 | 110 | return outputs 111 | 112 | 113 | class SmirkEncoder(nn.Module): 114 | def __init__(self, n_exp=50, n_shape=300) -> None: 115 | super().__init__() 116 | 117 | self.pose_encoder = PoseEncoder() 118 | 119 | self.shape_encoder = ShapeEncoder(n_shape=n_shape) 120 | 121 | self.expression_encoder = ExpressionEncoder(n_exp=n_exp) 122 | 123 | def forward(self, img): 124 | pose_outputs = self.pose_encoder(img) 125 | shape_outputs = self.shape_encoder(img) 126 | expression_outputs = self.expression_encoder(img) 127 | 128 | outputs = {} 129 | outputs.update(pose_outputs) 130 | outputs.update(shape_outputs) 131 | outputs.update(expression_outputs) 132 | 133 | return outputs 134 | -------------------------------------------------------------------------------- /decalib/smirk/utils/masking.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import cv2 7 | 8 | 9 | 10 | def load_probabilities_per_FLAME_triangle(): 11 | """ 12 | FLAME_masks_triangles.npy contains for each face area the indices of the triangles that belong to that area. 13 | Using that, we can assign a probability to each triangle based on the area it belongs to, and then sample for masking. 14 | """ 15 | flame_masks_triangles = np.load('assets/FLAME_masks/FLAME_masks_triangles.npy', allow_pickle=True).item() 16 | 17 | area_weights = { 18 | 'neck': 0.0, 19 | 'right_eyeball': 0.0, 20 | 'right_ear': 0.0, 21 | 'lips': 0.5, 22 | 'nose': 0.5, 23 | 'left_ear': 0.0, 24 | 'eye_region': 1.0, 25 | 'forehead':1.0, 26 | 'left_eye_region': 1.0, 27 | 'right_eye_region': 1.0, 28 | 'face_clean': 1.0, 29 | 'cleaner_lips': 1.0 30 | } 31 | 32 | face_probabilities = torch.zeros(9976) 33 | 34 | for area in area_weights.keys(): 35 | face_probabilities[flame_masks_triangles[area]] = area_weights[area] 36 | 37 | return face_probabilities 38 | 39 | 40 | def triangle_area(vertices): 41 | # Using the Shoelace formula to calculate the area of triangles in the xy plane 42 | # vertices is expected to be of shape (..., 3, 2) where the last dimension holds x and y coordinates. 
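For a quick sanity check of the Shoelace area computed below: the right triangle with vertices (0,0), (1,0), (0,1) gives 0.5 * |x1*y2 + x2*y3 + x3*y1 - x2*y1 - x3*y2 - x1*y3| = 0.5 (editor's example, not part of masking.py):

```python
import torch
tri = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])  # shape (1, 3, 2): one triangle, xy coordinates
print(triangle_area(tri))  # tensor([0.5000])
```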
43 | x1, y1 = vertices[..., 0, 0], vertices[..., 0, 1] 44 | x2, y2 = vertices[..., 1, 0], vertices[..., 1, 1] 45 | x3, y3 = vertices[..., 2, 0], vertices[..., 2, 1] 46 | 47 | # Shoelace formula for the area of a triangle given by coordinates (x1, y1), (x2, y2), (x3, y3) 48 | area = 0.5 * torch.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3) 49 | return area 50 | 51 | 52 | 53 | def random_barycentric(num=1): 54 | # Generate two random numbers for each set 55 | u, v = torch.rand(num), torch.rand(num) 56 | 57 | # Adjust the random numbers if they are outside the triangle 58 | outside_triangle = u + v > 1 59 | u[outside_triangle], v[outside_triangle] = 1 - u[outside_triangle], 1 - v[outside_triangle] 60 | 61 | # Calculate the barycentric coordinates 62 | alpha = 1 - (u + v) 63 | beta = u 64 | gamma = v 65 | 66 | # Combine the coordinates into a single tensor 67 | return torch.stack((alpha, beta, gamma), dim=1) 68 | 69 | 70 | def masking(img, mask, extra_points, wr=15, rendered_mask=None, extra_noise=True, random_mask=0.01): 71 | # img: B x C x H x W 72 | # mask: B x 1 x H x W 73 | 74 | B, C, H, W = img.size() 75 | 76 | # dilate face mask, drawn from convex hull of face landmarks 77 | mask = 1-F.max_pool2d(1-mask, 2 * wr + 1, stride=1, padding=wr) 78 | 79 | # optionally remove the rendered mask 80 | if rendered_mask is not None: 81 | mask = mask * (1 - rendered_mask) 82 | 83 | masked_img = img * mask 84 | # add noise to extra in-face points 85 | if extra_noise: 86 | # normal around 1 with std 0.1 87 | noise_mult = torch.randn(extra_points.shape).to(img.device) * 0.05 + 1 88 | extra_points = extra_points * noise_mult 89 | 90 | # select random_mask percentage of pixels as centers to crop out 11x11 patches 91 | if random_mask > 0: 92 | random_mask = torch.bernoulli(torch.ones((B, 1, H, W)) * random_mask).to(img.device) 93 | # dilate the mask to have 11x11 patches 94 | random_mask = 1 - F.max_pool2d(random_mask, 11, stride=1, padding=5) 95 | 96 | extra_points = extra_points * random_mask 97 | 98 | masked_img[extra_points > 0] = extra_points[extra_points > 0] 99 | 100 | masked_img = masked_img.detach() 101 | return masked_img 102 | 103 | 104 | 105 | def point2ind(npoints, H): 106 | 107 | npoints = npoints * (H // 2) + H // 2 108 | npoints = npoints.long() 109 | npoints[...,1] = torch.clamp(npoints[..., 1], 0, H-1) 110 | npoints[...,0] = torch.clamp(npoints[..., 0], 0, H-1) 111 | 112 | return npoints 113 | 114 | 115 | def transfer_pixels(img, points1, points2, rbound=None): 116 | 117 | B, C, H, W = img.size() 118 | retained_pixels = torch.zeros_like(img).to(img.device) 119 | 120 | if rbound is not None: 121 | for bi in range(B): 122 | retained_pixels[bi, :, points2[bi, :rbound[bi], 1], points2[bi, :rbound[bi], 0]] = \ 123 | img[bi, :, points1[bi, :rbound[bi], 1], points1[bi, :rbound[bi], 0]] 124 | else: 125 | retained_pixels[torch.arange(B).unsqueeze(-1), :, points2[..., 1], points2[..., 0]] = \ 126 | img[torch.arange(B).unsqueeze(-1), :, points1[..., 1], points1[..., 0]] 127 | 128 | return retained_pixels 129 | 130 | 131 | def mesh_based_mask_uniform_faces(flame_trans_verts, flame_faces, face_probabilities, mask_ratio=0.1, coords=None, IMAGE_SIZE=224): 132 | """ 133 | This function samples points from the FLAME mesh based on the face probabilities and the mask ratio. 
134 | """ 135 | batch_size = flame_trans_verts.size(0) 136 | DEVICE = flame_trans_verts.device 137 | 138 | # if mask_ratio is single value, then use it as a ratio of the image size 139 | num_points_to_sample = int(mask_ratio * IMAGE_SIZE * IMAGE_SIZE) 140 | 141 | flame_faces_expanded = flame_faces.expand(batch_size, -1, -1) 142 | 143 | if coords is None: 144 | # calculate face normals 145 | transformed_normals = vertex_normals(flame_trans_verts, flame_faces_expanded) 146 | transformed_face_normals = face_vertices(transformed_normals, flame_faces_expanded) 147 | transformed_face_normals = transformed_face_normals[:,:,:,2].mean(dim=-1) 148 | face_probabilities = face_probabilities.repeat(batch_size,1).to(flame_trans_verts.device) 149 | 150 | # # where the face normals are negative, set probability to 0 151 | face_probabilities = torch.where(transformed_face_normals < 0.05, face_probabilities, torch.zeros_like(transformed_face_normals).to(DEVICE)) 152 | # face_probabilities = torch.where(transformed_face_normals > 0, torch.ones_like(transformed_face_normals).to(flame_trans_verts.device), face_probabilities) 153 | 154 | # calculate xy area of faces and scale the probabilities by it 155 | fv = face_vertices(flame_trans_verts, flame_faces_expanded) 156 | xy_area = triangle_area(fv) 157 | 158 | face_probabilities = face_probabilities * xy_area 159 | 160 | 161 | sampled_faces_indices = torch.multinomial(face_probabilities, num_points_to_sample, replacement=True).to(DEVICE) 162 | 163 | barycentric_coords = random_barycentric(num=batch_size*num_points_to_sample).to(DEVICE) 164 | barycentric_coords = barycentric_coords.view(batch_size, num_points_to_sample, 3) 165 | else: 166 | sampled_faces_indices = coords['sampled_faces_indices'] 167 | barycentric_coords = coords['barycentric_coords'] 168 | 169 | npoints = vertices2landmarks(flame_trans_verts, flame_faces, sampled_faces_indices, barycentric_coords) 170 | 171 | npoints = .5 * (1 + npoints) * IMAGE_SIZE 172 | npoints = npoints.long() 173 | npoints[...,1] = torch.clamp(npoints[..., 1], 0, IMAGE_SIZE-1) 174 | npoints[...,0] = torch.clamp(npoints[..., 0], 0, IMAGE_SIZE-1) 175 | 176 | #mask = torch.zeros((flame_output['trans_verts'].size(0), 1, self.config.image_size, self.config.image_size)).to(flame_output['trans_verts'].device) 177 | 178 | #mask[torch.arange(batch_size).unsqueeze(-1), :, npoints[..., 1], npoints[..., 0]] = 1 179 | 180 | return npoints, {'sampled_faces_indices':sampled_faces_indices, 'barycentric_coords':barycentric_coords} 181 | -------------------------------------------------------------------------------- /decalib/smirk/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | def load_templates(): 6 | templates_path = "assets/expression_templates_famos" 7 | classes_to_load = ["lips_back", "rolling_lips", "mouth_side", "kissing", "high_smile", "mouth_up", 8 | "mouth_middle", "mouth_down", "blow_cheeks", "cheeks_in", "jaw", "lips_up"] 9 | templates = {} 10 | for subject in os.listdir(templates_path): 11 | if os.path.isdir(os.path.join(templates_path, subject)): 12 | for template in os.listdir(os.path.join(templates_path, subject)): 13 | if template.endswith(".mp4"): 14 | continue 15 | if template not in classes_to_load: 16 | continue 17 | exps = [] 18 | for npy_file in os.listdir(os.path.join(templates_path, subject, template)): 19 | params = np.load(os.path.join(templates_path, subject, template, npy_file), allow_pickle=True) 20 | 
exp = params.item()['expression'].squeeze() 21 | exps.append(exp) 22 | templates[subject+template] = np.array(exps) 23 | print('Number of expression templates loaded: ', len(templates.keys())) 24 | 25 | return templates 26 | 27 | 28 | 29 | def tensor_to_image(image_tensor): 30 | """Converts a tensor to a numpy image.""" 31 | image = image_tensor.permute(1,2,0).cpu().numpy()*255.0 32 | image = np.clip(image, 0, 255) 33 | image = image.astype(np.uint8) 34 | return image 35 | 36 | def image_to_tensor(image): 37 | """Converts a numpy image to a tensor.""" 38 | image = torch.from_numpy(image).permute(2,0,1).float()/255.0 39 | return image 40 | 41 | 42 | def count_parameters(model): 43 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 44 | 45 | 46 | def freeze_module(module, module_name=None): 47 | 48 | for param in module.parameters(): 49 | param.requires_grad_(False) 50 | 51 | module.eval() 52 | 53 | 54 | def unfreeze_module(module, module_name=None): 55 | 56 | for param in module.parameters(): 57 | param.requires_grad_(True) 58 | 59 | module.train() 60 | 61 | import cv2 62 | from torchvision.utils import make_grid 63 | 64 | 65 | def batch_draw_keypoints(images, landmarks, color=(255, 255, 255), radius=1): 66 | if isinstance(landmarks, torch.Tensor): 67 | landmarks = landmarks.cpu().numpy() 68 | landmarks = landmarks.copy()*112 + 112 69 | 70 | if isinstance(images, torch.Tensor): 71 | images = images.cpu().numpy().transpose(0, 2, 3, 1) 72 | images = (images * 255).astype('uint8') 73 | images = np.ascontiguousarray(images[..., ::-1]) 74 | 75 | plotted_images = [] 76 | for image, landmark in zip(images, landmarks): 77 | for point in landmark: 78 | image = cv2.circle(image, (int(point[0]), int(point[1])), radius, color, -1) 79 | plotted_images.append(image) 80 | 81 | return plotted_images 82 | 83 | def make_grid_from_opencv_images(images, nrow=12): 84 | """ Create a grid of images from the list of cv2 images in images""" 85 | images = np.array(images) 86 | images = images[..., ::-1] 87 | images = np.array(images) 88 | images = torch.from_numpy(images).permute(0, 3, 1, 2).float()/255. 
89 | grid = make_grid(images, nrow=nrow) 90 | return grid -------------------------------------------------------------------------------- /decalib/utils/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/utils/__pycache__/renderer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/renderer.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/utils/__pycache__/rotation_converter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/rotation_converter.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/utils/__pycache__/tensor_cropper.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/tensor_cropper.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/utils/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /decalib/utils/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Default config for DECA 3 | ''' 4 | from yacs.config import CfgNode as CN 5 | import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | 11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) 12 | abs_deca_dir = os.path.join(abs_deca_dir, 'decalib') 13 | cfg.deca_dir = abs_deca_dir 14 | cfg.device = 'cuda' 15 | cfg.device_id = '0' 16 | 17 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar') 18 | cfg.output_dir = '' 19 | cfg.rasterizer_type = 'pytorch3d' 20 | # ---------------------------------------------------------------------------- # 21 | # Options for Face model 22 | # ---------------------------------------------------------------------------- # 23 | cfg.model = CN() 24 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj') 25 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip 26 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy') 27 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy') 28 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl') 29 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy') 30 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png') 31 | 
cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png') 32 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg') 33 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz') 34 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM 35 | cfg.model.uv_size = 256 36 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 37 | cfg.model.n_shape = 100 38 | cfg.model.n_tex = 50 39 | cfg.model.n_exp = 50 40 | cfg.model.n_cam = 3 41 | cfg.model.n_pose = 6 42 | cfg.model.n_light = 27 43 | cfg.model.use_tex = True 44 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler. Note that: aa is not stable in the beginning 45 | # face recognition model 46 | cfg.model.fr_model_path = os.path.join(cfg.deca_dir, 'data', 'resnet50_ft_weight.pkl') 47 | 48 | ## details 49 | cfg.model.n_detail = 128 50 | cfg.model.max_z = 0.01 51 | 52 | # ---------------------------------------------------------------------------- # 53 | # Options for Dataset 54 | # ---------------------------------------------------------------------------- # 55 | cfg.dataset = CN() 56 | cfg.dataset.training_data = ['vggface2', 'ethnicity'] 57 | # cfg.dataset.training_data = ['ethnicity'] 58 | cfg.dataset.eval_data = ['aflw2000'] 59 | cfg.dataset.test_data = [''] 60 | cfg.dataset.batch_size = 2 61 | cfg.dataset.K = 4 62 | cfg.dataset.isSingle = False 63 | cfg.dataset.num_workers = 2 64 | cfg.dataset.image_size = 224 65 | cfg.dataset.scale_min = 1.4 66 | cfg.dataset.scale_max = 1.8 67 | cfg.dataset.trans_scale = 0. 68 | 69 | # ---------------------------------------------------------------------------- # 70 | # Options for training 71 | # ---------------------------------------------------------------------------- # 72 | cfg.train = CN() 73 | cfg.train.train_detail = False 74 | cfg.train.max_epochs = 500 75 | cfg.train.max_steps = 1000000 76 | cfg.train.lr = 1e-4 77 | cfg.train.log_dir = 'logs' 78 | cfg.train.log_steps = 10 79 | cfg.train.vis_dir = 'train_images' 80 | cfg.train.vis_steps = 200 81 | cfg.train.write_summary = True 82 | cfg.train.checkpoint_steps = 500 83 | cfg.train.val_steps = 500 84 | cfg.train.val_vis_dir = 'val_images' 85 | cfg.train.eval_steps = 5000 86 | cfg.train.resume = True 87 | 88 | # ---------------------------------------------------------------------------- # 89 | # Options for Losses 90 | # ---------------------------------------------------------------------------- # 91 | cfg.loss = CN() 92 | cfg.loss.lmk = 1.0 93 | cfg.loss.useWlmk = True 94 | cfg.loss.eyed = 1.0 95 | cfg.loss.lipd = 0.5 96 | cfg.loss.photo = 2.0 97 | cfg.loss.useSeg = True 98 | cfg.loss.id = 0.2 99 | cfg.loss.id_shape_only = True 100 | cfg.loss.reg_shape = 1e-04 101 | cfg.loss.reg_exp = 1e-04 102 | cfg.loss.reg_tex = 1e-04 103 | cfg.loss.reg_light = 1. 104 | cfg.loss.reg_jaw_pose = 0. #1. 105 | cfg.loss.use_gender_prior = False 106 | cfg.loss.shape_consistency = True 107 | # loss for detail 108 | cfg.loss.detail_consistency = True 109 | cfg.loss.useConstraint = True 110 | cfg.loss.mrf = 5e-2 111 | cfg.loss.photo_D = 2. 
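These defaults are normally consumed through the `get_cfg_defaults()` / `update_cfg()` helpers defined at the end of this file. A minimal sketch of overriding them from a YAML file follows (editor's example; `my_deca.yml` is a hypothetical override file):

```python
from decalib.utils.config import get_cfg_defaults, update_cfg

cfg = get_cfg_defaults()               # a clone of the defaults defined in this module
cfg = update_cfg(cfg, 'my_deca.yml')   # e.g. a YAML file containing "dataset:\n  batch_size: 8"
print(cfg.dataset.batch_size, cfg.model.n_exp, cfg.loss.photo)
```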
112 | cfg.loss.reg_sym = 0.005 113 | cfg.loss.reg_z = 0.005 114 | cfg.loss.reg_diff = 0.005 115 | 116 | 117 | def get_cfg_defaults(): 118 | """Get a yacs CfgNode object with default values for my_project.""" 119 | # Return a clone so that the defaults will not be altered 120 | # This is for the "local variable" use pattern 121 | return cfg.clone() 122 | 123 | def update_cfg(cfg, cfg_file): 124 | cfg.merge_from_file(cfg_file) 125 | return cfg.clone() 126 | 127 | def parse_args(): 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--cfg', type=str, help='cfg file path') 130 | parser.add_argument('--mode', type=str, default = 'train', help='deca mode') 131 | 132 | args = parser.parse_args() 133 | print(args, end='\n\n') 134 | 135 | cfg = get_cfg_defaults() 136 | cfg.cfg_file = None 137 | cfg.mode = args.mode 138 | # import ipdb; ipdb.set_trace() 139 | if args.cfg is not None: 140 | cfg_file = args.cfg 141 | cfg = update_cfg(cfg, args.cfg) 142 | cfg.cfg_file = cfg_file 143 | 144 | return cfg 145 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Install 2 | from standard_rasterize_cuda import standard_rasterize 3 | # from .rasterizer.standard_rasterize_cuda import standard_rasterize 4 | 5 | in this folder, run 6 | ```python setup.py build_ext -i ``` 7 | 8 | then remember to set --rasterizer_type=standard when runing demos :) 9 | 10 | ## Alg 11 | https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation 12 | 13 | ## Speed Comparison 14 | runtime for raterization only 15 | In PIXIE, number of faces in SMPLX: 20908 16 | 17 | for image size = 1024 18 | pytorch3d: 0.031s 19 | standard: 0.01s 20 | 21 | for image size = 224 22 | pytorch3d: 0.0035s 23 | standard: 0.0014s 24 | 25 | why standard rasterizer is faster than pytorch3d? 26 | Ref: https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu 27 | pytorch3d: for each pixel in image space (each pixel is parallel in cuda), loop through the faces, check if this pixel is in the projection bounding box of the face, then sorting faces according to z, record the face id of closest K faces. 
28 | standard rasterization: for each face in mesh (each face is parallel in cuda), loop through pixels in the projection bounding box (normally a very samll number), compare z, record face id of that pixel 29 | 30 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/rasterizer/__init__.py -------------------------------------------------------------------------------- /decalib/utils/rasterizer/setup.py: -------------------------------------------------------------------------------- 1 | # To install, run 2 | # python setup.py build_ext -i 3 | # Ref: https://github.com/pytorch/pytorch/blob/11a40410e755b1fe74efe9eaa635e7ba5712846b/test/cpp_extensions/setup.py#L62 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | import os 8 | 9 | # USE_NINJA = os.getenv('USE_NINJA') == '1' 10 | os.environ["CC"] = "gcc-7" 11 | os.environ["CXX"] = "gcc-7" 12 | 13 | USE_NINJA = os.getenv('USE_NINJA') == '1' 14 | 15 | setup( 16 | name='standard_rasterize_cuda', 17 | ext_modules=[ 18 | CUDAExtension('standard_rasterize_cuda', [ 19 | 'standard_rasterize_cuda.cpp', 20 | 'standard_rasterize_cuda_kernel.cu', 21 | ]) 22 | ], 23 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=USE_NINJA)} 24 | ) 25 | -------------------------------------------------------------------------------- /decalib/utils/rasterizer/standard_rasterize_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | std::vector forward_rasterize_cuda( 6 | at::Tensor face_vertices, 7 | at::Tensor depth_buffer, 8 | at::Tensor triangle_buffer, 9 | at::Tensor baryw_buffer, 10 | int h, 11 | int w); 12 | 13 | std::vector standard_rasterize( 14 | at::Tensor face_vertices, 15 | at::Tensor depth_buffer, 16 | at::Tensor triangle_buffer, 17 | at::Tensor baryw_buffer, 18 | int height, int width 19 | ) { 20 | return forward_rasterize_cuda(face_vertices, depth_buffer, triangle_buffer, baryw_buffer, height, width); 21 | } 22 | 23 | std::vector forward_rasterize_colors_cuda( 24 | at::Tensor face_vertices, 25 | at::Tensor face_colors, 26 | at::Tensor depth_buffer, 27 | at::Tensor triangle_buffer, 28 | at::Tensor images, 29 | int h, 30 | int w); 31 | 32 | std::vector standard_rasterize_colors( 33 | at::Tensor face_vertices, 34 | at::Tensor face_colors, 35 | at::Tensor depth_buffer, 36 | at::Tensor triangle_buffer, 37 | at::Tensor images, 38 | int height, int width 39 | ) { 40 | return forward_rasterize_colors_cuda(face_vertices, face_colors, depth_buffer, triangle_buffer, images, height, width); 41 | } 42 | 43 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 44 | m.def("standard_rasterize", &standard_rasterize, "RASTERIZE (CUDA)"); 45 | m.def("standard_rasterize_colors", &standard_rasterize_colors, "RASTERIZE COLORS (CUDA)"); 46 | } 47 | 48 | // TODO: backward -------------------------------------------------------------------------------- /decalib/utils/rasterizer/standard_rasterize_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | // Ref: https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/cuda/rasterize_cuda_kernel.cu 2 | // 
https://github.com/YadiraF/face3d/blob/master/face3d/mesh/cython/mesh_core.cpp 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | namespace{ 10 | __device__ __forceinline__ float atomicMin(float* address, float val) 11 | { 12 | int* address_as_i = (int*) address; 13 | int old = *address_as_i, assumed; 14 | do { 15 | assumed = old; 16 | old = atomicCAS(address_as_i, assumed, 17 | __float_as_int(fminf(val, __int_as_float(assumed)))); 18 | } while (assumed != old); 19 | return __int_as_float(old); 20 | } 21 | __device__ __forceinline__ double atomicMin(double* address, double val) 22 | { 23 | unsigned long long int* address_as_i = (unsigned long long int*) address; 24 | unsigned long long int old = *address_as_i, assumed; 25 | do { 26 | assumed = old; 27 | old = atomicCAS(address_as_i, assumed, 28 | __double_as_longlong(fminf(val, __longlong_as_double(assumed)))); 29 | } while (assumed != old); 30 | return __longlong_as_double(old); 31 | } 32 | 33 | template 34 | __device__ __forceinline__ bool check_face_frontside(const scalar_t *face) { 35 | return (face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]); 36 | } 37 | 38 | 39 | template struct point 40 | { 41 | public: 42 | scalar_t x; 43 | scalar_t y; 44 | 45 | __host__ __device__ scalar_t dot(point p) 46 | { 47 | return this->x * p.x + this->y * p.y; 48 | }; 49 | 50 | __host__ __device__ point operator-(point& p) 51 | { 52 | point np; 53 | np.x = this->x - p.x; 54 | np.y = this->y - p.y; 55 | return np; 56 | }; 57 | 58 | __host__ __device__ point operator+(point& p) 59 | { 60 | point np; 61 | np.x = this->x + p.x; 62 | np.y = this->y + p.y; 63 | return np; 64 | }; 65 | 66 | __host__ __device__ point operator*(scalar_t s) 67 | { 68 | point np; 69 | np.x = s * this->x; 70 | np.y = s * this->y; 71 | return np; 72 | }; 73 | }; 74 | 75 | template 76 | __device__ __forceinline__ bool check_pixel_inside(const scalar_t *w) { 77 | return w[0] <= 1 && w[0] >= 0 && w[1] <= 1 && w[1] >= 0 && w[2] <= 1 && w[2] >= 0; 78 | } 79 | 80 | template 81 | __device__ __forceinline__ void barycentric_weight(scalar_t *w, point p, point p0, point p1, point p2) { 82 | 83 | // vectors 84 | point v0, v1, v2; 85 | scalar_t s = p.dot(p); 86 | v0 = p2 - p0; 87 | v1 = p1 - p0; 88 | v2 = p - p0; 89 | 90 | // dot products 91 | scalar_t dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0) 92 | scalar_t dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1) 93 | scalar_t dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2) 94 | scalar_t dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1) 95 | scalar_t dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2) 96 | 97 | // barycentric coordinates 98 | scalar_t inverDeno; 99 | if(dot00*dot11 - dot01*dot01 == 0) 100 | inverDeno = 0; 101 | else 102 | inverDeno = 1/(dot00*dot11 - dot01*dot01); 103 | 104 | scalar_t u = (dot11*dot02 - dot01*dot12)*inverDeno; 105 | scalar_t v = (dot00*dot12 - dot01*dot02)*inverDeno; 106 | 107 | // weight 108 | w[0] = 1 - u - v; 109 | w[1] = v; 110 | w[2] = u; 111 | } 112 | 113 | // Ref: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/overview-rasterization-algorithm 114 | template 115 | __global__ void forward_rasterize_cuda_kernel( 116 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3] 117 | scalar_t* depth_buffer, 118 | int* triangle_buffer, 119 | scalar_t* baryw_buffer, 120 | int batch_size, int h, int w, 121 | int ntri) { 122 | 123 | const int 
i = blockIdx.x * blockDim.x + threadIdx.x; 124 | if (i >= batch_size * ntri) { 125 | return; 126 | } 127 | int bn = i/ntri; 128 | const scalar_t* face = &face_vertices[i * 9]; 129 | scalar_t bw[3]; 130 | point p0, p1, p2, p; 131 | 132 | p0.x = face[0]; p0.y=face[1]; 133 | p1.x = face[3]; p1.y=face[4]; 134 | p2.x = face[6]; p2.y=face[7]; 135 | 136 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0); 137 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1); 138 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0); 139 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1); 140 | 141 | for(int y = y_min; y <= y_max; y++) //h 142 | { 143 | for(int x = x_min; x <= x_max; x++) //w 144 | { 145 | p.x = x; p.y = y; 146 | barycentric_weight(bw, p, p0, p1, p2); 147 | // if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face)) 148 | if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) 149 | { 150 | // perspective correct: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/perspective-correct-interpolation-vertex-attributes 151 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]); 152 | // printf("%f %f %f \n", (float)zp, (float)face[2], (float)bw[2]); 153 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp); 154 | if(depth_buffer[bn*h*w + y*w + x] == zp) 155 | { 156 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri); 157 | for(int k=0; k<3; k++){ 158 | baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k]; 159 | } 160 | } 161 | } 162 | } 163 | } 164 | 165 | } 166 | 167 | template 168 | __global__ void forward_rasterize_colors_cuda_kernel( 169 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3] 170 | const scalar_t* __restrict__ face_colors, //[bz, nf, 3, 3] 171 | scalar_t* depth_buffer, 172 | int* triangle_buffer, 173 | scalar_t* images, 174 | int batch_size, int h, int w, 175 | int ntri) { 176 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 177 | if (i >= batch_size * ntri) { 178 | return; 179 | } 180 | int bn = i/ntri; 181 | const scalar_t* face = &face_vertices[i * 9]; 182 | const scalar_t* color = &face_colors[i * 9]; 183 | scalar_t bw[3]; 184 | point p0, p1, p2, p; 185 | 186 | p0.x = face[0]; p0.y=face[1]; 187 | p1.x = face[3]; p1.y=face[4]; 188 | p2.x = face[6]; p2.y=face[7]; 189 | scalar_t cl[3][3]; 190 | for (int num = 0; num < 3; num++) { 191 | for (int dim = 0; dim < 3; dim++) { 192 | cl[num][dim] = color[3 * num + dim]; //[3p,3rgb] 193 | } 194 | } 195 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0); 196 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1); 197 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0); 198 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1); 199 | 200 | for(int y = y_min; y <= y_max; y++) //h 201 | { 202 | for(int x = x_min; x <= x_max; x++) //w 203 | { 204 | p.x = x; p.y = y; 205 | barycentric_weight(bw, p, p0, p1, p2); 206 | if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face)) 207 | // if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) 208 | { 209 | scalar_t zp = 1. 
/ (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]); 210 | 211 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp); 212 | if(depth_buffer[bn*h*w + y*w + x] == zp) 213 | { 214 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri); 215 | for(int k=0; k<3; k++){ 216 | // baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k]; 217 | images[bn*h*w*3 + y*w*3 + x*3 + k] = bw[0]*cl[0][k] + bw[1]*cl[1][k] + bw[2]*cl[2][k]; 218 | } 219 | // buffers[bn*h*w*2 + y*w*2 + x*2 + 1] = p_depth; 220 | } 221 | } 222 | } 223 | } 224 | 225 | } 226 | 227 | } 228 | 229 | std::vector forward_rasterize_cuda( 230 | at::Tensor face_vertices, 231 | at::Tensor depth_buffer, 232 | at::Tensor triangle_buffer, 233 | at::Tensor baryw_buffer, 234 | int h, 235 | int w){ 236 | 237 | const auto batch_size = face_vertices.size(0); 238 | const auto ntri = face_vertices.size(1); 239 | 240 | // print(channel_size) 241 | const int threads = 512; 242 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1); 243 | 244 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda1", ([&] { 245 | forward_rasterize_cuda_kernel<<>>( 246 | face_vertices.data(), 247 | depth_buffer.data(), 248 | triangle_buffer.data(), 249 | baryw_buffer.data(), 250 | batch_size, h, w, 251 | ntri); 252 | })); 253 | 254 | // better to do it twice (or there will be balck spots in the rendering) 255 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda2", ([&] { 256 | forward_rasterize_cuda_kernel<<>>( 257 | face_vertices.data(), 258 | depth_buffer.data(), 259 | triangle_buffer.data(), 260 | baryw_buffer.data(), 261 | batch_size, h, w, 262 | ntri); 263 | })); 264 | cudaError_t err = cudaGetLastError(); 265 | if (err != cudaSuccess) 266 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err)); 267 | 268 | return {depth_buffer, triangle_buffer, baryw_buffer}; 269 | } 270 | 271 | 272 | std::vector forward_rasterize_colors_cuda( 273 | at::Tensor face_vertices, 274 | at::Tensor face_colors, 275 | at::Tensor depth_buffer, 276 | at::Tensor triangle_buffer, 277 | at::Tensor images, 278 | int h, 279 | int w){ 280 | 281 | const auto batch_size = face_vertices.size(0); 282 | const auto ntri = face_vertices.size(1); 283 | 284 | // print(channel_size) 285 | const int threads = 512; 286 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1); 287 | //initial 288 | 289 | AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] { 290 | forward_rasterize_colors_cuda_kernel<<>>( 291 | face_vertices.data(), 292 | face_colors.data(), 293 | depth_buffer.data(), 294 | triangle_buffer.data(), 295 | images.data(), 296 | batch_size, h, w, 297 | ntri); 298 | })); 299 | // better to do it twice 300 | // AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] { 301 | // forward_rasterize_colors_cuda_kernel<<>>( 302 | // face_vertices.data(), 303 | // face_colors.data(), 304 | // depth_buffer.data(), 305 | // triangle_buffer.data(), 306 | // images.data(), 307 | // batch_size, h, w, 308 | // ntri); 309 | // })); 310 | cudaError_t err = cudaGetLastError(); 311 | if (err != cudaSuccess) 312 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err)); 313 | 314 | return {depth_buffer, triangle_buffer, images}; 315 | } 316 | 317 | 318 | 319 | 320 | -------------------------------------------------------------------------------- /decalib/utils/tensor_cropper.py: -------------------------------------------------------------------------------- 1 | ''' 2 
| crop 3 | for torch tensor 4 | Given image, bbox(center, bboxsize) 5 | return: cropped image, tform(used for transform the keypoint accordingly) 6 | only support crop to squared images 7 | ''' 8 | import torch 9 | from kornia.geometry.transform.imgwarp import ( 10 | warp_perspective, get_perspective_transform, warp_affine 11 | ) 12 | 13 | def points2bbox(points, points_scale=None): 14 | if points_scale: 15 | assert points_scale[0]==points_scale[1] 16 | points = points.clone() 17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0] 18 | min_coords, _ = torch.min(points, dim=1) 19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1] 20 | max_coords, _ = torch.max(points, dim=1) 21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1] 22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5 23 | 24 | width = (xmax - xmin) 25 | height = (ymax - ymin) 26 | # Convert the bounding box to a square box 27 | size = torch.max(width, height).unsqueeze(-1) 28 | return center, size 29 | 30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.): 31 | batch_size = center.shape[0] 32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale 33 | center = center + trans_scale*bbox_size # 0.5 34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0] 35 | size = bbox_size*scale 36 | return center, size 37 | 38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False): 39 | ''' for batch image 40 | Args: 41 | image (torch.Tensor): the reference tensor of shape BXHxWXC. 42 | center: [bz, 2] 43 | bboxsize: [bz, 1] 44 | crop_size; 45 | interpolation (str): Interpolation flag. Default: 'bilinear'. 46 | align_corners (bool): mode for grid_generation. Default: False. 
See 47 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details 48 | Returns: 49 | cropped_image 50 | tform 51 | ''' 52 | dtype = image.dtype 53 | device = image.device 54 | batch_size = image.shape[0] 55 | # points: top-left, top-right, bottom-right, bottom-left 56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous() 57 | 58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1) 59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5 60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5 61 | src_pts[:, 2, :] = center + bbox_size * 0.5 62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5 63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5 64 | 65 | DST_PTS = torch.tensor([[ 66 | [0, 0], 67 | [crop_size - 1, 0], 68 | [crop_size - 1, crop_size - 1], 69 | [0, crop_size - 1], 70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1) 71 | # estimate transformation between points 72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS) 73 | # simulate broadcasting 74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1) 75 | 76 | # warp images 77 | cropped_image = warp_affine( 78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size), 79 | flags=interpolation, align_corners=align_corners) 80 | 81 | tform = torch.transpose(dst_trans_src, 2, 1) 82 | # tform = torch.inverse(dst_trans_src) 83 | return cropped_image, tform 84 | 85 | class Cropper(object): 86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.): 87 | self.crop_size = crop_size 88 | self.scale = scale 89 | self.trans_scale = trans_scale 90 | 91 | def crop(self, image, points, points_scale=None): 92 | # points to bbox 93 | center, bbox_size = points2bbox(points.clone(), points_scale) 94 | # argument bbox. TODO: add rotation? 
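# Descriptive note: augment_bbox ("argument bbox" above) randomly shifts the box center by trans_scale and rescales it within [scale[0], scale[1]]; with the defaults scale=[1, 1] and trans_scale=0. it leaves the box unchanged.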
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale) 96 | # crop 97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size) 98 | return cropped_image, tform 99 | 100 | def transform_points(self, points, tform, points_scale=None, normalize = True): 101 | points_2d = points[:,:,:2] 102 | 103 | #'input points must use original range' 104 | if points_scale: 105 | assert points_scale[0]==points_scale[1] 106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 107 | 108 | batch_size, n_points, _ = points.shape 109 | trans_points_2d = torch.bmm( 110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 111 | tform 112 | ) 113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 114 | if normalize: 115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1 116 | return trans_points 117 | 118 | def transform_points(points, tform, points_scale=None, out_scale=None): 119 | points_2d = points[:,:,:2] 120 | 121 | #'input points must use original range' 122 | if points_scale: 123 | assert points_scale[0]==points_scale[1] 124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 125 | # import ipdb; ipdb.set_trace() 126 | 127 | batch_size, n_points, _ = points.shape 128 | trans_points_2d = torch.bmm( 129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 130 | tform 131 | ) 132 | if out_scale: # h,w of output image size 133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1 134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1 135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 136 | return trans_points -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | from datetime import datetime 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from PIL import Image 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | from diffusers.utils.import_utils import is_xformers_available 13 | from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from transformers import CLIPTokenizer, CLIPTextModel 16 | 17 | from models.guidance_encoder import GuidanceEncoder 18 | from models.mgportrait_model import MgpModel 19 | from models.mutual_self_attention import ReferenceAttentionControl 20 | from models.unet_2d_condition import UNet2DConditionModel 21 | from models.unet_3d import UNet3DConditionModel 22 | from pipelines.pipeline_aggregation import MultiGuidance2LongVideoPipeline 23 | from utils.video_utils import resize_tensor_frames, save_videos_grid, pil_list_to_tensor 24 | 25 | 26 | def tokenize_captions(tokenizer, captions): 27 | inputs = tokenizer( 28 | captions, 29 | max_length=tokenizer.model_max_length, 30 | padding="max_length", 31 | truncation=True, 32 | return_tensors="pt" 33 | ) 34 | return inputs.input_ids 35 | 36 | 37 | def setup_savedir(cfg): 38 | time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 39 | if cfg.exp_name is None: 40 | savedir = f"results/animations/exp-{time_str}" 41 | else: 42 | savedir = f"results/animations/{cfg.exp_name}-{time_str}" 43 | 44 | os.makedirs(savedir, exist_ok=True) 45 | 46 | return savedir 47 | 48 | 49 | def 
setup_guidance_encoder(cfg): 50 | guidance_encoder_group = dict() 51 | 52 | if cfg.weight_dtype == "fp16": 53 | weight_dtype = torch.float16 54 | else: 55 | weight_dtype = torch.float32 56 | weight_dtype = torch.float32 57 | # ['depth', 'normal', 'dwpose'] 58 | for guidance_type in cfg.guidance_types: 59 | guidance_encoder_group[guidance_type] = GuidanceEncoder( 60 | guidance_embedding_channels=cfg.guidance_encoder_kwargs.guidance_embedding_channels, 61 | guidance_input_channels=cfg.guidance_encoder_kwargs.guidance_input_channels, 62 | block_out_channels=cfg.guidance_encoder_kwargs.block_out_channels, 63 | ).to(device="cuda", dtype=weight_dtype) 64 | 65 | return guidance_encoder_group 66 | 67 | 68 | 69 | def combine_guidance_data(cfg): 70 | guidance_types = cfg.guidance_types 71 | guidance_data_folder = cfg.data.guidance_data_folder 72 | 73 | guidance_pil_group = dict() 74 | for guidance_type in guidance_types: 75 | guidance_pil_group[guidance_type] = [] 76 | guidance_image_lst = sorted( 77 | Path(osp.join(guidance_data_folder, guidance_type)).iterdir() 78 | ) 79 | guidance_image_lst = ( 80 | guidance_image_lst 81 | if not cfg.data.frame_range 82 | else guidance_image_lst[cfg.data.frame_range[0]:cfg.data.frame_range[1]] 83 | ) 84 | 85 | for guidance_image_path in guidance_image_lst: 86 | guidance_pil_group[guidance_type] += [ 87 | Image.open(guidance_image_path).convert("RGB") 88 | ] 89 | 90 | first_guidance_length = len(list(guidance_pil_group.values())[0]) 91 | assert all( 92 | len(sublist) == first_guidance_length 93 | for sublist in list(guidance_pil_group.values()) 94 | ) 95 | 96 | return guidance_pil_group, first_guidance_length 97 | 98 | 99 | 100 | def inference( 101 | cfg, 102 | vae, 103 | text_encoder, 104 | tokenizer, 105 | model, 106 | scheduler, 107 | ref_image_pil, 108 | guidance_pil_group, 109 | video_length, 110 | width, 111 | height, 112 | device="cuda", 113 | dtype=torch.float16, 114 | ): 115 | reference_unet = model.reference_unet 116 | denoising_unet = model.denoising_unet 117 | guidance_types = cfg.guidance_types 118 | guidance_encoder_group = { 119 | f"guidance_encoder_{g}": getattr(model, f"guidance_encoder_{g}") 120 | for g in guidance_types 121 | } 122 | 123 | generator = torch.Generator(device=device) 124 | generator.manual_seed(cfg.seed) 125 | pipeline = MultiGuidance2LongVideoPipeline( 126 | vae=vae, 127 | reference_unet=reference_unet, 128 | denoising_unet=denoising_unet, 129 | **guidance_encoder_group, 130 | scheduler=scheduler, 131 | guidance_process_size=cfg.data.get("guidance_process_size", None) 132 | ) 133 | pipeline = pipeline.to(device, dtype) 134 | 135 | prompt = "A close up of a person." 
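# Descriptive note: the fixed caption is tokenized and run through the CLIP text encoder; the resulting embeddings (roughly 1 x 77 x 768 for an SD 1.x-style text encoder, an assumption based on the 768-dim comments elsewhere in this repo) are passed to the pipeline as cross-attention conditioning for the reference and denoising UNets.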
136 | prompt_embeds = text_encoder(tokenize_captions(tokenizer, [prompt]).to('cuda'))[0] 137 | 138 | video = pipeline( 139 | ref_image_pil, 140 | prompt_embeds, 141 | guidance_pil_group, 142 | width, 143 | height, 144 | video_length, 145 | num_inference_steps=cfg.num_inference_steps, 146 | guidance_scale=cfg.guidance_scale, 147 | generator=generator, 148 | ).videos 149 | 150 | del pipeline 151 | torch.cuda.empty_cache() 152 | 153 | return video 154 | 155 | 156 | def main(cfg): 157 | logging.basicConfig( 158 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 159 | datefmt="%m/%d/%Y %H:%M:%S", 160 | level=logging.INFO, 161 | ) 162 | 163 | save_dir = setup_savedir(cfg) 164 | logging.info(f"Running inference ...") 165 | weight_dtype = torch.float32 166 | 167 | sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) 168 | # {'num_train_timesteps': 1000, 'beta_start': 0.00085, 'beta_end': 0.012, 169 | # 'beta_schedule': 'linear', 'steps_offset': 1, 'clip_sample': False} 170 | if cfg.enable_zero_snr: 171 | sched_kwargs.update( 172 | rescale_betas_zero_snr=True, 173 | timestep_spacing="trailing", 174 | prediction_type="v_prediction", 175 | ) 176 | noise_scheduler = DDIMScheduler(**sched_kwargs) 177 | sched_kwargs.update({"beta_schedule": "scaled_linear"}) 178 | 179 | vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( 180 | dtype=weight_dtype, device="cuda" 181 | ) 182 | 183 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 184 | cfg.base_model_path, 185 | cfg.motion_module_path, 186 | subfolder="unet", 187 | unet_additional_kwargs=cfg.unet_additional_kwargs, 188 | ).to(dtype=weight_dtype, device="cuda") 189 | 190 | reference_unet = UNet2DConditionModel.from_pretrained( 191 | cfg.base_model_path, 192 | subfolder="unet", 193 | ).to(device="cuda", dtype=weight_dtype) 194 | 195 | text_encoder = CLIPTextModel.from_pretrained( 196 | cfg.base_model_path, 197 | subfolder="text_encoder", 198 | ).to(device="cuda") 199 | 200 | tokenizer = CLIPTokenizer.from_pretrained( 201 | cfg.base_model_path, 202 | subfolder="tokenizer", 203 | ) 204 | 205 | guidance_encoder_group = setup_guidance_encoder(cfg) 206 | 207 | 208 | ckpt_dir = cfg.ckpt_dir 209 | denoising_unet.load_state_dict( 210 | torch.load( 211 | osp.join(ckpt_dir, f"denoising_unet.pth"), 212 | map_location="cpu", 213 | ), 214 | strict=False, 215 | ) 216 | reference_unet.load_state_dict( 217 | torch.load( 218 | osp.join(ckpt_dir, f"reference_unet.pth"), 219 | map_location="cpu", 220 | ), 221 | strict=False, 222 | ) 223 | 224 | for guidance_type, guidance_encoder_module in guidance_encoder_group.items(): 225 | guidance_encoder_module.load_state_dict( 226 | torch.load( 227 | osp.join(ckpt_dir, f"guidance_encoder_{guidance_type}.pth"), 228 | map_location="cpu", 229 | ), 230 | strict=False, 231 | ) 232 | 233 | reference_control_writer = ReferenceAttentionControl( 234 | reference_unet, 235 | do_classifier_free_guidance=False, 236 | mode="write", 237 | fusion_blocks="full", 238 | ) 239 | reference_control_reader = ReferenceAttentionControl( 240 | denoising_unet, 241 | do_classifier_free_guidance=False, 242 | mode="read", 243 | fusion_blocks="full", 244 | ) 245 | 246 | model = MgpModel( 247 | reference_unet=reference_unet, 248 | denoising_unet=denoising_unet, 249 | reference_control_writer=reference_control_writer, 250 | reference_control_reader=reference_control_reader, 251 | guidance_encoder_group=guidance_encoder_group, 252 | ).to("cuda", dtype=weight_dtype) 253 | 254 | if 
cfg.enable_xformers_memory_efficient_attention: 255 | if is_xformers_available(): 256 | reference_unet.enable_xformers_memory_efficient_attention() 257 | denoising_unet.enable_xformers_memory_efficient_attention() 258 | else: 259 | raise ValueError( 260 | "xformers is not available. Make sure it is installed correctly" 261 | ) 262 | 263 | ref_image_path = cfg.data.ref_image_path 264 | ref_image_pil = Image.open(ref_image_path) 265 | ref_image_w, ref_image_h = ref_image_pil.size 266 | 267 | 268 | guidance_pil_group, video_length = combine_guidance_data(cfg) 269 | 270 | result_video_tensor = inference( 271 | cfg=cfg, 272 | vae=vae, 273 | text_encoder=text_encoder, 274 | tokenizer=tokenizer, 275 | model=model, 276 | scheduler=noise_scheduler, 277 | ref_image_pil=ref_image_pil, 278 | guidance_pil_group=guidance_pil_group, 279 | video_length=video_length, 280 | width=cfg.width, 281 | height=cfg.height, 282 | device="cuda", 283 | dtype=weight_dtype, 284 | ) # (1, c, f, h, w) 285 | 286 | result_video_tensor = resize_tensor_frames( 287 | result_video_tensor, (ref_image_h, ref_image_w) 288 | ) 289 | save_videos_grid(result_video_tensor, osp.join(save_dir, "animation.mp4")) 290 | 291 | ref_video_tensor = transforms.ToTensor()(ref_image_pil)[None, :, None, ...].repeat( 292 | 1, 1, video_length, 1, 1 293 | ) 294 | guidance_video_tensor_lst = [] 295 | for guidance_pil_lst in guidance_pil_group.values(): 296 | guidance_video_tensor_lst += [ 297 | pil_list_to_tensor(guidance_pil_lst, size=(ref_image_h, ref_image_w)) 298 | ] 299 | guidance_video_tensor = torch.stack(guidance_video_tensor_lst, dim=0) 300 | 301 | grid_video = torch.cat([ref_video_tensor, result_video_tensor], dim=0) 302 | grid_video_wguidance = torch.cat( 303 | [ref_video_tensor, result_video_tensor, guidance_video_tensor], dim=0 304 | ) 305 | 306 | save_videos_grid(grid_video, osp.join(save_dir, "grid.mp4")) 307 | save_videos_grid(grid_video_wguidance, osp.join(save_dir, "guidance.mp4")) 308 | 309 | logging.info(f"Inference completed, results saved in {save_dir}") 310 | 311 | 312 | if __name__ == "__main__": 313 | 314 | parser = argparse.ArgumentParser() 315 | parser.add_argument("--config", type=str, default="./configs/inference/inference.yaml") 316 | args = parser.parse_args() 317 | 318 | if args.config[-5:] == ".yaml": 319 | cfg = OmegaConf.load(args.config) 320 | else: 321 | raise ValueError("Do not support this format config file") 322 | 323 | main(cfg) 324 | -------------------------------------------------------------------------------- /models/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/exp_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/exp_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/guidance_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/guidance_encoder.cpython-310.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/mgportrait_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/mgportrait_model.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/motion_module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/motion_module.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/mutual_self_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/mutual_self_attention.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer_2d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/transformer_2d.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/transformer_3d.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_2d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_2d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_2d_condition.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_2d_condition.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_3d.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet_3d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_3d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /models/exp_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | from diffusers.models.modeling_utils import ModelMixin 8 | from diffusers.utils import BaseOutput 9 | from dataclasses import dataclass 10 | 11 | 12 | 13 | 14 | class ExpEncoder(ModelMixin): 15 | def __init__( 16 | self, 17 | input_size: int = 64, 18 | hidden_sizes=None, 19 | output_size: int = 768, 20 | 21 | ): 22 | super().__init__() 23 | if hidden_sizes is None: 24 | hidden_sizes = [128, 256, 512] 25 | self.layers = nn.ModuleList() 26 | self.layers.append(nn.Linear(input_size, hidden_sizes[0])) 27 | for i in range(1, len(hidden_sizes)): 28 | self.layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])) 29 | self.layers.append(nn.Linear(hidden_sizes[-1], output_size)) 30 | 31 | def forward(self, x): 32 | for layer in self.layers[:-1]: 33 | x = F.relu(layer(x)) 34 | # 输出层不加激活函数 35 | x = self.layers[-1](x) 36 | 37 | return x 38 | -------------------------------------------------------------------------------- /models/guidance_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | from diffusers.models.modeling_utils import ModelMixin 8 | from diffusers.utils import BaseOutput 9 | from dataclasses import dataclass 10 | from transformers import CLIPVisionModelWithProjection 11 | 12 | from models.motion_module import zero_module 13 | from models.resnet import InflatedConv3d, InflatedGroupNorm 14 | from models.attention import TemporalBasicTransformerBlock 15 | from models.transformer_3d import Transformer3DModel 16 | 17 | 18 | class GuidanceEncoder(ModelMixin): 19 | def __init__( 20 | self, 21 | guidance_embedding_channels: int, 22 | guidance_input_channels: int = 3, 23 | block_out_channels: Tuple[int] = (16, 32, 96, 256), 24 | attention_num_heads: int = 8, 25 | ): 26 | super().__init__() 27 | self.guidance_input_channels = guidance_input_channels 28 | self.conv_in = InflatedConv3d( 29 | guidance_input_channels, block_out_channels[0], kernel_size=3, padding=1 30 | ) 31 | 32 | self.blocks = nn.ModuleList([]) 33 | self.attentions = nn.ModuleList([]) 34 | 35 | for i in range(len(block_out_channels) - 1): 36 | channel_in = block_out_channels[i] 37 | channel_out = block_out_channels[i + 1] 38 | 39 | self.blocks.append( 40 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 41 | ) 42 | self.attentions.append( 43 | Transformer3DModel( 44 | attention_num_heads, 45 | channel_in // attention_num_heads, 46 | channel_in, 47 | norm_num_groups=1, 48 | unet_use_cross_frame_attention=False, 49 | unet_use_temporal_attention=False, 50 | ) 51 | ) 52 | 53 | self.blocks.append( 54 | InflatedConv3d( 55 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 56 | ) 57 | ) 58 | self.attentions.append( 59 | Transformer3DModel( 60 | attention_num_heads, 61 | channel_out // attention_num_heads, 62 | channel_out, 63 | norm_num_groups=32, 64 | unet_use_cross_frame_attention=False, 65 | unet_use_temporal_attention=False, 66 | ) 
67 | ) 68 | 69 | attention_channel_out = block_out_channels[-1] 70 | self.guidance_attention = Transformer3DModel( 71 | attention_num_heads, 72 | attention_channel_out // attention_num_heads, 73 | attention_channel_out, 74 | norm_num_groups=32, 75 | unet_use_cross_frame_attention=False, 76 | unet_use_temporal_attention=False, 77 | ) 78 | 79 | self.conv_out = zero_module( 80 | InflatedConv3d( 81 | block_out_channels[-1], 82 | guidance_embedding_channels, 83 | kernel_size=3, 84 | padding=1, 85 | ) 86 | ) 87 | 88 | def forward(self, condition): 89 | 90 | embedding = self.conv_in(condition) 91 | embedding = F.silu(embedding) 92 | 93 | for block in self.blocks: 94 | embedding = block(embedding) 95 | embedding = F.silu(embedding) 96 | 97 | embedding = self.attentions[-1](embedding).sample 98 | embedding = self.conv_out(embedding) 99 | return embedding 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /models/mgportrait_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.unet_2d_condition import UNet2DConditionModel 4 | from models.unet_3d import UNet3DConditionModel 5 | from models.exp_encoder import ExpEncoder 6 | from einops import rearrange 7 | 8 | class MgpModel(nn.Module): 9 | def __init__( 10 | self, 11 | reference_unet: UNet2DConditionModel, 12 | denoising_unet: UNet3DConditionModel, 13 | reference_control_writer, 14 | reference_control_reader, 15 | guidance_encoder_group, 16 | ): 17 | super().__init__() 18 | self.reference_unet = reference_unet 19 | self.denoising_unet = denoising_unet 20 | 21 | self.reference_control_writer = reference_control_writer 22 | self.reference_control_reader = reference_control_reader 23 | 24 | self.guidance_types = [] 25 | self.guidance_input_channels = [] 26 | 27 | for guidance_type, guidance_module in guidance_encoder_group.items(): 28 | setattr(self, f"guidance_encoder_{guidance_type}", guidance_module) 29 | self.guidance_types.append(guidance_type) 30 | self.guidance_input_channels.append(guidance_module.guidance_input_channels) 31 | 32 | def forward( 33 | self, 34 | noisy_latents, 35 | timesteps, 36 | ref_image_latents, 37 | multi_guidance_cond, 38 | text_embeds, 39 | uncond_fwd: bool = False, 40 | ): 41 | guidance_cond_group = torch.split( 42 | multi_guidance_cond, self.guidance_input_channels, dim=1 43 | ) 44 | guidance_fea_lst = [] 45 | for guidance_idx, guidance_cond in enumerate(guidance_cond_group): 46 | guidance_encoder = getattr( 47 | self, f"guidance_encoder_{self.guidance_types[guidance_idx]}" 48 | ) 49 | guidance_fea = guidance_encoder(guidance_cond) 50 | guidance_fea_lst += [guidance_fea] 51 | guidance_fea = torch.stack(guidance_fea_lst, dim=0).sum(0) 52 | 53 | # video_length = exp_embedding.shape[1] 54 | # exp_embedding = rearrange(exp_embedding, "b f d -> (b f) d") 55 | # exp_embed = self.exp_encoder(exp_embedding) 56 | # exp_embed = rearrange(exp_embed, "(b f) d -> b f d", f=video_length) 57 | 58 | 59 | if not uncond_fwd: 60 | ref_timesteps = torch.zeros_like(timesteps) 61 | self.reference_unet( 62 | ref_image_latents, 63 | ref_timesteps, 64 | encoder_hidden_states=text_embeds, # [4, 77, 768] 65 | return_dict=False, 66 | ) 67 | 68 | self.reference_control_reader.update(self.reference_control_writer) 69 | 70 | model_pred = self.denoising_unet( 71 | noisy_latents, 72 | timesteps, 73 | guidance_fea=guidance_fea, 74 | encoder_hidden_states=text_embeds, # [btz, frames, dim] 75 | ).sample 
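# Descriptive note: model_pred has the same shape as noisy_latents (b, c, f, h, w); whether it is read as predicted noise or a v-prediction depends on the scheduler's prediction_type.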
76 | 77 | return model_pred 78 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class InflatedConv3d(nn.Conv2d): 10 | def forward(self, x): 11 | video_length = x.shape[2] 12 | 13 | x = rearrange(x, "b c f h w -> (b f) c h w") 14 | x = super().forward(x) 15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 16 | 17 | return x 18 | 19 | 20 | class InflatedGroupNorm(nn.GroupNorm): 21 | def forward(self, x): 22 | video_length = x.shape[2] 23 | 24 | x = rearrange(x, "b c f h w -> (b f) c h w") 25 | x = super().forward(x) 26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 27 | 28 | return x 29 | 30 | 31 | class Upsample3D(nn.Module): 32 | def __init__( 33 | self, 34 | channels, 35 | use_conv=False, 36 | use_conv_transpose=False, 37 | out_channels=None, 38 | name="conv", 39 | ): 40 | super().__init__() 41 | self.channels = channels 42 | self.out_channels = out_channels or channels 43 | self.use_conv = use_conv 44 | self.use_conv_transpose = use_conv_transpose 45 | self.name = name 46 | 47 | conv = None 48 | if use_conv_transpose: 49 | raise NotImplementedError 50 | elif use_conv: 51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 52 | 53 | def forward(self, hidden_states, output_size=None): 54 | assert hidden_states.shape[1] == self.channels 55 | 56 | if self.use_conv_transpose: 57 | raise NotImplementedError 58 | 59 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 60 | dtype = hidden_states.dtype 61 | if dtype == torch.bfloat16: 62 | hidden_states = hidden_states.to(torch.float32) 63 | 64 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 65 | if hidden_states.shape[0] >= 64: 66 | hidden_states = hidden_states.contiguous() 67 | 68 | # if `output_size` is passed we force the interpolation output 69 | # size and do not make use of `scale_factor=2` 70 | if output_size is None: 71 | hidden_states = F.interpolate( 72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 73 | ) 74 | else: 75 | hidden_states = F.interpolate( 76 | hidden_states, size=output_size, mode="nearest" 77 | ) 78 | 79 | # If the input is bfloat16, we cast back to bfloat16 80 | if dtype == torch.bfloat16: 81 | hidden_states = hidden_states.to(dtype) 82 | 83 | # if self.use_conv: 84 | # if self.name == "conv": 85 | # hidden_states = self.conv(hidden_states) 86 | # else: 87 | # hidden_states = self.Conv2d_0(hidden_states) 88 | hidden_states = self.conv(hidden_states) 89 | 90 | return hidden_states 91 | 92 | 93 | class Downsample3D(nn.Module): 94 | def __init__( 95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 96 | ): 97 | super().__init__() 98 | self.channels = channels 99 | self.out_channels = out_channels or channels 100 | self.use_conv = use_conv 101 | self.padding = padding 102 | stride = 2 103 | self.name = name 104 | 105 | if use_conv: 106 | self.conv = InflatedConv3d( 107 | self.channels, self.out_channels, 3, stride=stride, padding=padding 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | def forward(self, hidden_states): 113 | assert hidden_states.shape[1] == self.channels 114 | if self.use_conv and self.padding == 0: 115 | raise NotImplementedError 116 | 117 | assert hidden_states.shape[1] == self.channels 118 | hidden_states = self.conv(hidden_states) 119 | 120 | return hidden_states 121 | 122 | 123 | class ResnetBlock3D(nn.Module): 124 | def __init__( 125 | self, 126 | *, 127 | in_channels, 128 | out_channels=None, 129 | conv_shortcut=False, 130 | dropout=0.0, 131 | temb_channels=512, 132 | groups=32, 133 | groups_out=None, 134 | pre_norm=True, 135 | eps=1e-6, 136 | non_linearity="swish", 137 | time_embedding_norm="default", 138 | output_scale_factor=1.0, 139 | use_in_shortcut=None, 140 | use_inflated_groupnorm=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | assert use_inflated_groupnorm != None 156 | if use_inflated_groupnorm: 157 | self.norm1 = InflatedGroupNorm( 158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 159 | ) 160 | else: 161 | self.norm1 = torch.nn.GroupNorm( 162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 163 | ) 164 | 165 | self.conv1 = InflatedConv3d( 166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 167 | ) 168 | 169 | if temb_channels is not None: 170 | if self.time_embedding_norm == "default": 171 | time_emb_proj_out_channels = out_channels 172 | elif self.time_embedding_norm == "scale_shift": 173 | time_emb_proj_out_channels = out_channels * 2 174 | else: 175 | raise ValueError( 176 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 177 | ) 178 | 179 | self.time_emb_proj = torch.nn.Linear( 180 | temb_channels, 
time_emb_proj_out_channels 181 | ) 182 | else: 183 | self.time_emb_proj = None 184 | 185 | if use_inflated_groupnorm: 186 | self.norm2 = InflatedGroupNorm( 187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 188 | ) 189 | else: 190 | self.norm2 = torch.nn.GroupNorm( 191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 192 | ) 193 | self.dropout = torch.nn.Dropout(dropout) 194 | self.conv2 = InflatedConv3d( 195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 196 | ) 197 | 198 | if non_linearity == "swish": 199 | self.nonlinearity = lambda x: F.silu(x) 200 | elif non_linearity == "mish": 201 | self.nonlinearity = Mish() 202 | elif non_linearity == "silu": 203 | self.nonlinearity = nn.SiLU() 204 | 205 | self.use_in_shortcut = ( 206 | self.in_channels != self.out_channels 207 | if use_in_shortcut is None 208 | else use_in_shortcut 209 | ) 210 | 211 | self.conv_shortcut = None 212 | if self.use_in_shortcut: 213 | self.conv_shortcut = InflatedConv3d( 214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 215 | ) 216 | 217 | def forward(self, input_tensor, temb): 218 | hidden_states = input_tensor 219 | 220 | hidden_states = self.norm1(hidden_states) 221 | hidden_states = self.nonlinearity(hidden_states) 222 | 223 | hidden_states = self.conv1(hidden_states) 224 | 225 | if temb is not None: 226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 227 | 228 | if temb is not None and self.time_embedding_norm == "default": 229 | hidden_states = hidden_states + temb 230 | 231 | hidden_states = self.norm2(hidden_states) 232 | 233 | if temb is not None and self.time_embedding_norm == "scale_shift": 234 | scale, shift = torch.chunk(temb, 2, dim=1) 235 | hidden_states = hidden_states * (1 + scale) + shift 236 | 237 | hidden_states = self.nonlinearity(hidden_states) 238 | 239 | hidden_states = self.dropout(hidden_states) 240 | hidden_states = self.conv2(hidden_states) 241 | 242 | if self.conv_shortcut is not None: 243 | input_tensor = self.conv_shortcut(input_tensor) 244 | 245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 246 | 247 | return output_tensor 248 | 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | -------------------------------------------------------------------------------- /models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.models import ModelMixin 7 | from diffusers.utils import BaseOutput 8 | from diffusers.utils.import_utils import is_xformers_available 9 | from einops import rearrange, repeat 10 | from torch import nn 11 | 12 | from .attention import TemporalBasicTransformerBlock 13 | 14 | 15 | @dataclass 16 | class Transformer3DModelOutput(BaseOutput): 17 | sample: torch.FloatTensor 18 | 19 | 20 | if is_xformers_available(): 21 | import xformers 22 | import xformers.ops 23 | else: 24 | xformers = None 25 | 26 | 27 | class Transformer3DModel(ModelMixin, ConfigMixin): 28 | _supports_gradient_checkpointing = True 29 | 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: 
int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm( 59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 60 | ) 61 | if use_linear_projection: 62 | self.proj_in = nn.Linear(in_channels, inner_dim) 63 | else: 64 | self.proj_in = nn.Conv2d( 65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 66 | ) 67 | 68 | # Define transformers blocks 69 | self.transformer_blocks = nn.ModuleList( 70 | [ 71 | TemporalBasicTransformerBlock( 72 | inner_dim, 73 | num_attention_heads, 74 | attention_head_dim, 75 | dropout=dropout, 76 | cross_attention_dim=cross_attention_dim, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | attention_bias=attention_bias, 80 | only_cross_attention=only_cross_attention, 81 | upcast_attention=upcast_attention, 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d( 94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 95 | ) 96 | 97 | self.gradient_checkpointing = False 98 | 99 | def _set_gradient_checkpointing(self, module, value=False): 100 | if hasattr(module, "gradient_checkpointing"): 101 | module.gradient_checkpointing = value 102 | 103 | def forward( 104 | self, 105 | hidden_states, 106 | encoder_hidden_states=None, 107 | timestep=None, 108 | return_dict: bool = True, 109 | ): 110 | # Input 111 | assert ( 112 | hidden_states.dim() == 5 113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
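# Descriptive note: as a hypothetical example, a (b, c, f, h, w) = (1, 320, 16, 32, 32) input is flattened below to ((b f), c, h, w) = (16, 320, 32, 32) so every frame shares the same 2D transformer blocks, and is reshaped back to b c f h w at the end of this forward pass.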
114 | video_length = hidden_states.shape[2] 115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 116 | if encoder_hidden_states is not None: 117 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 118 | encoder_hidden_states = repeat( 119 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 120 | ) 121 | 122 | batch, channel, height, weight = hidden_states.shape 123 | residual = hidden_states 124 | 125 | hidden_states = self.norm(hidden_states) 126 | if not self.use_linear_projection: 127 | hidden_states = self.proj_in(hidden_states) 128 | inner_dim = hidden_states.shape[1] 129 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 130 | batch, height * weight, inner_dim 131 | ) 132 | else: 133 | inner_dim = hidden_states.shape[1] 134 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 135 | batch, height * weight, inner_dim 136 | ) 137 | hidden_states = self.proj_in(hidden_states) 138 | 139 | # Blocks 140 | for i, block in enumerate(self.transformer_blocks): 141 | hidden_states = block( 142 | hidden_states, 143 | encoder_hidden_states=encoder_hidden_states, 144 | timestep=timestep, 145 | video_length=video_length, 146 | ) 147 | 148 | # Output 149 | if not self.use_linear_projection: 150 | hidden_states = ( 151 | hidden_states.reshape(batch, height, weight, inner_dim) 152 | .permute(0, 3, 1, 2) 153 | .contiguous() 154 | ) 155 | hidden_states = self.proj_out(hidden_states) 156 | else: 157 | hidden_states = self.proj_out(hidden_states) 158 | hidden_states = ( 159 | hidden_states.reshape(batch, height, weight, inner_dim) 160 | .permute(0, 3, 1, 2) 161 | .contiguous() 162 | ) 163 | 164 | output = hidden_states + residual 165 | 166 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 167 | if not return_dict: 168 | return (output,) 169 | 170 | return Transformer3DModelOutput(sample=output) 171 | -------------------------------------------------------------------------------- /pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__init__.py -------------------------------------------------------------------------------- /pipelines/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/context.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/context.cpython-310.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/pipe_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/pipe_utils.cpython-310.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/pipeline_aggregation.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/pipeline_aggregation.cpython-310.pyc -------------------------------------------------------------------------------- /pipelines/context.py: -------------------------------------------------------------------------------- 1 | # TODO: Adapted from cli 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | def ordered_halving(val): 8 | bin_str = f"{val:064b}" 9 | bin_flip = bin_str[::-1] 10 | as_int = int(bin_flip, 2) 11 | 12 | return as_int / (1 << 64) 13 | 14 | 15 | def uniform( 16 | step: int = ..., 17 | num_steps: Optional[int] = None, 18 | num_frames: int = ..., 19 | context_size: Optional[int] = None, 20 | context_stride: int = 3, 21 | context_overlap: int = 4, 22 | closed_loop: bool = True, 23 | ): 24 | if num_frames <= context_size: 25 | yield list(range(num_frames)) 26 | return 27 | 28 | context_stride = min( 29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 30 | ) 31 | 32 | for context_step in 1 << np.arange(context_stride): 33 | pad = int(round(num_frames * ordered_halving(step))) 34 | for j in range( 35 | int(ordered_halving(step) * context_step) + pad, 36 | num_frames + pad + (0 if closed_loop else -context_overlap), 37 | (context_size * context_step - context_overlap), 38 | ): 39 | yield [ 40 | e % num_frames 41 | for e in range(j, j + context_size * context_step, context_step) 42 | ] 43 | 44 | 45 | def get_context_scheduler(name: str) -> Callable: 46 | if name == "uniform": 47 | return uniform 48 | else: 49 | raise ValueError(f"Unknown context_overlap policy {name}") 50 | 51 | 52 | def get_total_steps( 53 | scheduler, 54 | timesteps: List[int], 55 | num_steps: Optional[int] = None, 56 | num_frames: int = ..., 57 | context_size: Optional[int] = None, 58 | context_stride: int = 3, 59 | context_overlap: int = 4, 60 | closed_loop: bool = True, 61 | ): 62 | return sum( 63 | len( 64 | list( 65 | scheduler( 66 | i, 67 | num_steps, 68 | num_frames, 69 | context_size, 70 | context_stride, 71 | context_overlap, 72 | ) 73 | ) 74 | ) 75 | for i in range(len(timesteps)) 76 | ) 77 | -------------------------------------------------------------------------------- /pipelines/pipe_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_interpolation = None 4 | 5 | 6 | def get_tensor_interpolation_method(): 7 | return tensor_interpolation 8 | 9 | 10 | def set_tensor_interpolation_method(is_slerp): 11 | global tensor_interpolation 12 | tensor_interpolation = slerp if is_slerp else linear 13 | 14 | 15 | def linear(v1, v2, t): 16 | return (1.0 - t) * v1 + t * v2 17 | 18 | 19 | def slerp( 20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 21 | ) -> torch.Tensor: 22 | u0 = v0 / v0.norm() 23 | u1 = v1 / v1.norm() 24 | dot = (u0 * u1).sum() 25 | if dot.abs() > DOT_THRESHOLD: 26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.') 27 | return (1.0 - t) * v0 + t * v1 28 | omega = dot.acos() 29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 30 | -------------------------------------------------------------------------------- /render_and_transfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch as th 5 | from torchvision.utils import save_image 6 | import argparse 7 | 
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
8 |
9 |
10 | from decalib.deca_with_smirk import DECA
11 | from decalib.utils.config import cfg as deca_cfg
12 | from data_utils.transfer_utils import get_image_dict
13 |
14 | # Build DECA
15 | deca_cfg.model.use_tex = True
16 | deca_cfg.model.tex_path = "./decalib/data/FLAME_texture.npz"
17 | deca_cfg.model.tex_type = "FLAME"
18 | deca = DECA(config=deca_cfg, device="cuda")
19 |
20 |
21 |
22 | def get_render(source, target, save_file):
23 |
24 |     src_dict = get_image_dict(source, 512, True)
25 |     tar_dict = get_image_dict(target, 512, True)
26 |     # ===================get DECA codes of the target image===============================
27 |     tar_cropped = tar_dict["image"].unsqueeze(0).to("cuda")
28 |     imgname = tar_dict["imagename"]
29 |
30 |     with th.no_grad():
31 |         tar_code = deca.encode(tar_cropped)
32 |     tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
33 |     # ===================get DECA codes of the source image===============================
34 |     src_cropped = src_dict["image"].unsqueeze(0).to("cuda")
35 |     with th.no_grad():
36 |         src_code = deca.encode(src_cropped)
37 |     # To align the face when the pose is changing
38 |     src_ffhq_center = deca.decode(src_code, return_ffhq_center=True)
39 |     tar_ffhq_center = deca.decode(tar_code, return_ffhq_center=True)
40 |
41 |     src_tform = src_dict["tform"].unsqueeze(0)
42 |     src_tform = th.inverse(src_tform).transpose(1, 2).to("cuda")
43 |     src_code["tform"] = src_tform
44 |
45 |     tar_tform = tar_dict["tform"].unsqueeze(0)
46 |     tar_tform = th.inverse(tar_tform).transpose(1, 2).to("cuda")
47 |     tar_code["tform"] = tar_tform
48 |
49 |     src_image = src_dict["original_image"].unsqueeze(0).to("cuda")  # source image at original resolution (not used below)
50 |     tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
51 |
52 |     # code1 holds the source DECA code, code2 the target DECA code
53 |     code1, code2 = {}, {}
54 |     for k in src_code:
55 |         code1[k] = src_code[k].clone()
56 |
57 |     for k in tar_code:
58 |         code2[k] = tar_code[k].clone()
59 |     # transfer head rotation, expression and jaw pose from the target code to the source code
60 |     code1["pose"][:, :3] = code2["pose"][:, :3]
61 |     code1['exp'] = code2['exp']
62 |     code1['pose'][:, 3:] = tar_code['pose'][:, 3:]
63 |
64 |     opdict, _ = deca.decode(
65 |         code1,
66 |         render_orig=True,
67 |         original_image=tar_image,
68 |         tform=src_code["tform"],
69 |         align_ffhq=False,
70 |         ffhq_center=src_ffhq_center,
71 |         imgpath=target
72 |     )
73 |
74 |     depth = opdict["depth_images"].detach()
75 |     normal = opdict["normal_images"].detach()
76 |     render = opdict["rendered_images"].detach()
77 |     os.makedirs(f'./transfers/{save_file}/depth', exist_ok=True)
78 |     os.makedirs(f'./transfers/{save_file}/normal', exist_ok=True)
79 |     os.makedirs(f'./transfers/{save_file}/render', exist_ok=True)
80 |
81 |     save_image(depth[0], f"./transfers/{save_file}/depth/{imgname}")
82 |     save_image(normal[0], f"./transfers/{save_file}/normal/{imgname}")
83 |     save_image(render[0], f"./transfers/{save_file}/render/{imgname}")
84 |
85 |
86 | if __name__ == '__main__':
87 |     parser = argparse.ArgumentParser()
88 |     parser.add_argument(
89 |         "--sor_img",
90 |         type=str,
91 |         default='/home/mengting/projects/diffusionRig/myscripts/papers/example42/boy2_cropped.jpg',
92 |         required=False
93 |     )
94 |     parser.add_argument(
95 |         "--driving_path",
96 |         type=str,
97 |         default='/home/mengting/Desktop/frames_1500_updated/1fsFQ2gF4oE_0/images',
98 |         required=False
99 |     )
100 |     parser.add_argument(
101 |         "--save_name",
102 |         type=str,
103 |         default='example1',
104 |         required=False
105 |     )
106 |
107 |     args = parser.parse_args()
108 |     images = sorted(os.listdir(args.driving_path))
109 |     for image in images:
110 |         cur_image_path = os.path.join(args.driving_path, image)
111 |         get_render(args.sor_img, cur_image_path, args.save_name)
112 |
113 |     print('done')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | diffusers==0.32.2
3 | einops==0.8.1
4 | face_alignment==1.4.1
5 | facenet_pytorch==2.6.0
6 | imageio==2.37.0
7 | insightface==0.7.3
8 | ipdb==0.13.13
9 | kornia==0.8.0
10 | loguru==0.7.3
11 | mediapipe==0.10.21
12 | moviepy==2.1.2
13 | numpy==2.2.5
14 | omegaconf==2.3.0
15 | opencv_contrib_python==4.11.0.86
16 | opencv_python==4.11.0.86
17 | opencv_python_headless==4.11.0.86
18 | Pillow==11.2.1
19 | pyarrow==19.0.0
20 | pytorch3d==0.7.8
21 | PyYAML==6.0.2
22 | safetensors==0.5.3
23 | scikit-image  # imported as "skimage"
24 | scipy==1.15.2
25 | setuptools==75.8.0
26 | timm==0.6.13
27 | torchfile==0.1.0
28 | tqdm==4.67.1
29 | transformers==4.49.0
30 | xformers==0.0.28
31 | yacs==0.1.8
32 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/video_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__pycache__/video_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/download.py:
--------------------------------------------------------------------------------
1 | import urllib
2 | import logging
3 | import os
4 | from pathlib import Path
5 | import urllib.request
6 | import tqdm
7 |
8 |
9 | def download(url: str, output: Path) -> None:
10 |     if not os.path.exists(output.parent):
11 |         os.makedirs(output.parent, exist_ok=True)
12 |
13 |     response = urllib.request.urlopen(url)
14 |     content_length = response.info().get("Content-Length")
15 |     if content_length is None:
16 |         raise ValueError("invalid content length")
17 |     content_length = int(content_length)
18 |
19 |     if os.path.exists(output):
20 |         if os.path.getsize(output) == content_length:
21 |             print(f"{output} exists. Download skipped.")
22 |             return
23 |
24 |     saved_size = 0
25 |
26 |     pbar = tqdm.tqdm(total=content_length)
27 |     with open(output, "wb") as f:
28 |         while True:
29 |             chunk = response.read(8192)
30 |             if not chunk:
31 |                 break
32 |             f.write(chunk)
33 |             saved_size += len(chunk)
34 |             pbar.update(len(chunk))
35 |
36 |     if saved_size != content_length:
37 |         os.remove(output)
38 |         raise BlockingIOError("download failed; incomplete file was removed")
file cleared") 39 | -------------------------------------------------------------------------------- /utils/fs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Generator 4 | 5 | 6 | def traverse_folder(folder: Path) -> Generator[Path, Any, Any]: 7 | if os.path.exists(folder) and os.path.isdir(folder): 8 | children = sorted(os.listdir(folder)) 9 | for child in children: 10 | child_abs_path = Path(os.path.join(folder, child)) 11 | if os.path.isfile(child_abs_path): 12 | yield child_abs_path 13 | else: 14 | yield from traverse_folder(child_abs_path) 15 | else: 16 | raise ValueError(f"{folder} does not exist or is not a directory") 17 | -------------------------------------------------------------------------------- /utils/postprocess.py: -------------------------------------------------------------------------------- 1 | from moviepy import VideoFileClip, concatenate_videoclips, clips_array, ImageClip, ColorClip, concatenate_videoclips, vfx 2 | import os 3 | import cv2 4 | from pyarrow import duration 5 | 6 | 7 | def print_video_info(video, name): 8 | fps = video.fps 9 | duration = video.duration # 秒数 10 | frame_count = int(fps * duration) 11 | 12 | print(f"视频{name}:") 13 | print(f" 帧率: {fps} fps") 14 | print(f" 时长: {duration:.2f} 秒") 15 | print(f" 总帧数: {frame_count} 帧") 16 | print("") 17 | 18 | def images_to_video(image_dir, output_path, fps=24): 19 | # 获取图像文件并按自然顺序排序 20 | image_files = sorted([ 21 | f for f in os.listdir(image_dir) if f.endswith((".jpg", ".jpeg", ".png"))]) 22 | 23 | 24 | # 读取第一张图像获取尺寸 25 | first_image = cv2.imread(os.path.join(image_dir, image_files[0])) 26 | if first_image is None: 27 | raise ValueError("❌ 无法读取第一张图像。") 28 | height, width = first_image.shape[:2] 29 | 30 | # 设置视频编码器 31 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # .mp4 输出 32 | video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 33 | 34 | for img_name in image_files: 35 | img_path = os.path.join(image_dir, img_name) 36 | frame = cv2.imread(img_path) 37 | 38 | # 自动 resize 不一致的尺寸 39 | if frame.shape[:2] != (height, width): 40 | frame = cv2.resize(frame, (width, height)) 41 | 42 | video_writer.write(frame) 43 | 44 | video_writer.release() 45 | print(f"✅ 视频保存成功: {output_path}") 46 | 47 | def concatenate_videos(): 48 | # 加载三个视频 49 | clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut1.mp4") 50 | clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut2.mp4") 51 | # clip3 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut3.mp4") 52 | 53 | # 拼接视频 54 | final_clip = concatenate_videoclips([clip1, clip2]) 55 | 56 | # 输出到新文件 57 | final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/output_combined.mp4", codec="libx264") 58 | 59 | 60 | def array_videos(): 61 | # 加载两个视频 62 | clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4") 63 | clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/animation.mp4") 64 | 65 | # 横向拼接(按帧对齐) 66 | final_clip = clips_array([[clip1, clip2]]) # 一行两列,clip1 左,clip2 右 67 | 68 | # 导出结果 69 | final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/side_by_side.mp4", codec="libx264") 70 | 71 | def set_fps(): 72 | clip = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions.mp4") 73 | 74 | # 设定新帧率(不会插帧) 75 | clip = clip.with_fps(24) 76 | 77 | # 写出时指定新帧率(帧数不变,播放速度加快) 78 | 
78 |     clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions_24.mp4", fps=24, codec="libx264")
79 |
80 |
81 | def add_static_image():
82 |     # load the side-by-side video
83 |     video = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/output_fast_30fps.mp4")
84 |
85 |     # load the static image, match the video height, and make it last for the whole video duration
86 |     image = ImageClip("/home/mengting/Desktop/tmp/champ/gpu_1/boy2_cropped.jpg").resized(height=video.h).with_duration(video.duration)
87 |
88 |     # place them side by side (static image + video)
89 |     final = clips_array([[image, video]])
90 |
91 |     # export the result
92 |     final.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/final_output.mp4", codec="libx264")
93 | # concatenated 179, 360 - 179 = 181
94 | # concatenated 181 and 117, 181 - 117 = 64
95 | # concatenated 64 and 150, 150 - 64 = 86
96 | # video 7 and video 8
97 | # video 1 is 275, video2 is 96, video 3 is 360, video5 is 117, video6 is 150, video7 is 98, video8 is 102
98 | def print_info():
99 |     # load the two videos
100 |     video1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person1/video_part1.mp4")
101 |     video2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person2/final_output.mp4")
102 |
103 |     fps = video1.fps
104 |     duration = video1.duration  # seconds
105 |     frame_count = int(fps * duration)
106 |     print(f"  Frame rate (fps): {fps}")
107 |     print(f"  Total duration (s): {duration:.2f}")
108 |     print(f"  Total frames: {frame_count}")
109 |     print("")
110 |
111 | def split_video():
112 |     video = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person8/final_output.mp4")
113 |
114 |     # get the frame rate and the total frame count
115 |     fps = video.fps
116 |     total_frames = int(video.duration * fps)  # or simply hard-code total_frames = 275
117 |
118 |     # set the split point
119 |     first_part_frames = 98
120 |     first_part_duration = first_part_frames / fps  # in seconds
121 |     total_duration = video.duration
122 |
123 |     # first segment (the first 98 frames)
124 |     clip1 = video.subclipped(0, first_part_duration)
125 |
126 |     # second segment (the remaining frames)
127 |     clip2 = video.subclipped(first_part_duration, total_duration)
128 |
129 |     # save both parts
130 |     clip1.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_1.mp4", codec="libx264", fps=fps)
131 |     clip2.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_2.mp4", codec="libx264", fps=fps)
132 |
133 |
134 | def up_and_down():
135 |     # load the two videos
136 |     clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person7/final_output.mp4")
137 |     clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_1.mp4")
138 |
139 |
140 |     # stack the two clips vertically
141 |     final = clips_array([[clip1], [clip2]])
142 |
143 |     # write out the stacked video
144 |     final.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_5.mp4", codec="libx264", fps=clip1.fps)
145 |
146 |
147 | def concate_all():
148 |     # list of video files to join
149 |     video_files = ["/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_1.mp4",
150 |                    "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_2.mp4",
151 |                    "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_3.mp4",
152 |                    "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_4.mp4",
153 |                    "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_5.mp4"]
154 |
155 |     # fade-in / fade-out durations (seconds)
156 |     fade_duration = 1.0
157 |     black_duration = 0.5
158 |
159 |     # load the videos and add fade-in / fade-out effects
160 |     clips = []
161 |     for idx, file in enumerate(video_files):
162 |         clip = VideoFileClip(file)
163 |         clip = clip.with_effects([vfx.FadeIn(duration=fade_duration)])
164 |         clip = clip.with_effects([vfx.FadeOut(duration=black_duration)])
165 |         clips.append(clip)
166 |
167 |         # insert a black clip between videos (but not after the last one)
168 |         if idx < len(video_files) - 1:
169 |             black = ColorClip(size=clip.size, color=(0, 0, 0), duration=black_duration)
170 |             clips.append(black)
171 |
172 |     # concatenate all segments
173 |     final_clip = concatenate_videoclips(clips, method="compose")
174 |
175 |     # export the final video
176 |     final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions.mp4", codec="libx264", fps=clips[0].fps)
177 |
178 | if __name__ == '__main__':
179 |     #concatenate_videos()
180 |     #images_to_video('/home/mengting/Desktop/frames_1500_updated/1fsFQ2gF4oE_0/images', '/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4')
181 |     #array_videos()
182 |     # video1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4")
183 |     # print_video_info(video1, "driving")
184 |     set_fps()
185 |     #add_static_image()
186 |     #print_info()
187 |     #split_video()
188 |     #up_and_down()
189 |     #concate_all()
--------------------------------------------------------------------------------
/utils/tb_tracker.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from accelerate.tracking import GeneralTracker, on_main_process
3 | import os
4 | from typing import Union
5 |
6 |
7 | class TbTracker(GeneralTracker):
8 |
9 |     name = "tensorboard"
10 |     requires_logging_directory = True
11 |
12 |     @on_main_process
13 |     def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike],
14 |                  **kwargs):
15 |         super().__init__()
16 |         self.run_name = run_name
17 |         self.logging_dir = os.path.join(logging_dir, run_name)
18 |         self.writer = SummaryWriter(self.logging_dir, **kwargs)
19 |
20 |     @property
21 |     def tracker(self):
22 |         return self.writer
23 |
24 |     @on_main_process
25 |     def add_scalar(self, tag, scalar_value, **kwargs):
26 |         self.writer.add_scalar(tag=tag, scalar_value=scalar_value, **kwargs)
27 |
28 |     @on_main_process
29 |     def add_text(self, tag, text_string, **kwargs):
30 |         self.writer.add_text(tag=tag, text_string=text_string, **kwargs)
31 |
32 |     @on_main_process
33 |     def add_figure(self, tag, figure, **kwargs):
34 |         self.writer.add_figure(tag=tag, figure=figure, **kwargs)
35 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import sys
5 | from pathlib import Path
6 |
7 | # import av
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from einops import rearrange
12 | from PIL import Image
13 |
14 |
15 | def seed_everything(seed):
16 |     import random
17 |     import numpy as np
18 |
19 |     torch.manual_seed(seed)
20 |     torch.cuda.manual_seed_all(seed)
21 |     np.random.seed(seed % (2**32))
22 |     random.seed(seed)
23 |
24 | def delete_additional_ckpt(base_path, num_keep):
25 |     dirs = []
26 |     for d in os.listdir(base_path):
27 |         if d.startswith("checkpoint-"):
28 |             dirs.append(d)
29 |     num_tot = len(dirs)
30 |     if num_tot <= num_keep:
31 |         return
32 |     # ensure checkpoints are sorted and delete the earliest ones
33 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] 34 | for d in del_dirs: 35 | path_to_dir = osp.join(base_path, d) 36 | if osp.exists(path_to_dir): 37 | shutil.rmtree(path_to_dir) 38 | 39 | def compute_snr(noise_scheduler, timesteps): 40 | """ 41 | Computes SNR as per 42 | https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 43 | """ 44 | alphas_cumprod = noise_scheduler.alphas_cumprod 45 | sqrt_alphas_cumprod = alphas_cumprod**0.5 46 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 47 | 48 | # Expand the tensors. 49 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 50 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ 51 | timesteps 52 | ].float() 53 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 54 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 55 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 56 | 57 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( 58 | device=timesteps.device 59 | )[timesteps].float() 60 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 61 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 62 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 63 | 64 | # Compute SNR. 65 | snr = (alpha / sigma) ** 2 66 | return snr 67 | -------------------------------------------------------------------------------- /utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | from pathlib import Path 8 | import imageio 9 | from einops import rearrange 10 | import torchvision.transforms as transforms 11 | 12 | 13 | def save_videos_from_pil(pil_images, path, fps=15, crf=23): 14 | 15 | save_fmt = Path(path).suffix 16 | os.makedirs(os.path.dirname(path), exist_ok=True) 17 | 18 | if save_fmt == ".mp4": 19 | with imageio.get_writer(path, fps=fps) as writer: 20 | for img in pil_images: 21 | img_array = np.array(img) # Convert PIL Image to numpy array 22 | writer.append_data(img_array) 23 | 24 | elif save_fmt == ".gif": 25 | pil_images[0].save( 26 | fp=path, 27 | format="GIF", 28 | append_images=pil_images[1:], 29 | save_all=True, 30 | duration=(1 / fps * 1000), 31 | loop=0, 32 | ) 33 | else: 34 | raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") 35 | 36 | 37 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=24): 38 | videos = rearrange(videos, "b c t h w -> t b c h w") 39 | height, width = videos.shape[-2:] 40 | outputs = [] 41 | 42 | for x in videos: 43 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) 44 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) 45 | if rescale: 46 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 47 | x = (x * 255).numpy().astype(np.uint8) 48 | x = Image.fromarray(x) 49 | 50 | outputs.append(x) 51 | 52 | os.makedirs(os.path.dirname(path), exist_ok=True) 53 | 54 | save_videos_from_pil(outputs, path, fps) 55 | 56 | 57 | def resize_tensor_frames(video_tensor, new_size): 58 | B, C, video_length, H, W = video_tensor.shape 59 | # Reshape video tensor to combine batch and frame dimensions: (B*F, C, H, W) 60 | video_tensor_reshaped = video_tensor.reshape(-1, C, H, W) 61 | # Resize using interpolate 62 | resized_frames = F.interpolate( 63 | video_tensor_reshaped, size=new_size, mode="bilinear", align_corners=False 64 | ) 65 | resized_video = resized_frames.reshape(B, C, video_length, new_size[0], new_size[1]) 66 | 67 | return resized_video 68 | 69 | 70 | def pil_list_to_tensor(image_list, size=None): 71 | to_tensor = transforms.ToTensor() 72 | if size is not None: 73 | tensor_list = [to_tensor(img.resize(size[::-1])) for img in image_list] 74 | else: 75 | tensor_list = [to_tensor(img) for img in image_list] 76 | stacked_tensor = torch.stack(tensor_list, dim=0) 77 | tensor = stacked_tensor.permute(1, 0, 2, 3) 78 | return tensor 79 | 80 | 81 | def conatenate_into_video(): 82 | gt_list = [] 83 | gt_root = '/home/mengting/projects/champ_abls/no_exp_coeff/results/output_images' 84 | imgs = sorted(os.listdir(gt_root)) 85 | for img in imgs: 86 | cur_img_path = os.path.join(gt_root, img) 87 | tmp_img = Image.open(cur_img_path) 88 | tmp_img = tmp_img.resize((512, 512)) 89 | tmp_img = transforms.ToTensor()(tmp_img) 90 | gt_list.append(tmp_img) 91 | gt_list = torch.stack(gt_list, dim=1).unsqueeze(0) 92 | print(gt_list.shape) 93 | 94 | ref_image_path = '/home/mengting/Desktop/frames_1500_updated/4Z7qKXu9Sck_2/images/frame_0000.jpg' 95 | ref_image_pil = Image.open(ref_image_path) 96 | ref_image_w, ref_image_h = ref_image_pil.size 97 | video_length = len(imgs) 98 | ref_video_tensor = transforms.ToTensor()(ref_image_pil)[None, :, None, ...].repeat( 99 | 1, 1, video_length, 1, 1 100 | ) 101 | 102 | drive_list = [] 103 | guidance_path = '/home/mengting/Desktop/frames_new_1500/2yj1P52T1X8_4/images' 104 | imgs = sorted(os.listdir(guidance_path)) 105 | for i, img in enumerate(imgs): 106 | cur_img_path = os.path.join(guidance_path, img) 107 | tmp_img = Image.open(cur_img_path) 108 | tmp_img = transforms.ToTensor()(tmp_img) 109 | drive_list.append(tmp_img) 110 | if len(drive_list) == video_length: 111 | break 112 | drive_list = torch.stack(drive_list, dim=1).unsqueeze(0) 113 | print(drive_list.shape, ref_video_tensor.shape) 114 | 115 | save_dir = '/home/mengting/projects/champ_abls/no_exp_coeff/results/comparison' 116 | grid_video = torch.cat([drive_list, ref_video_tensor, gt_list], dim=0) 117 | save_videos_grid(grid_video, os.path.join(save_dir, "grid_wdrive_aniportrait.mp4")) 118 | 119 | if __name__ == '__main__': 120 | conatenate_into_video() --------------------------------------------------------------------------------