├── README.md
├── assets
│   ├── boy.jpeg
│   ├── boy_cropped.jpg
│   ├── combined_with_transitions_24.mp4
│   └── driving.mp4
├── configs
│   └── inference
│       └── inference.yaml
├── crop_process.py
├── data_utils
│   ├── __pycache__
│   │   ├── datasets_faceswap.cpython-310.pyc
│   │   └── transfer_utils.cpython-310.pyc
│   ├── datasets_faceswap.py
│   └── transfer_utils.py
├── decalib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   └── deca_with_smirk.cpython-310.pyc
│   ├── datasets
│   │   ├── __pycache__
│   │   │   ├── datasets.cpython-310.pyc
│   │   │   └── detectors.cpython-310.pyc
│   │   ├── aflw2000.py
│   │   ├── build_datasets.py
│   │   ├── datasets.py
│   │   ├── detectors.py
│   │   ├── ethnicity.py
│   │   ├── now.py
│   │   ├── train_datasets.py
│   │   ├── vggface.py
│   │   └── vox.py
│   ├── deca_with_smirk.py
│   ├── models
│   │   ├── FLAME.py
│   │   ├── __pycache__
│   │   │   ├── FLAME.cpython-310.pyc
│   │   │   ├── decoders.cpython-310.pyc
│   │   │   ├── encoders.cpython-310.pyc
│   │   │   ├── lbs.cpython-310.pyc
│   │   │   └── resnet.cpython-310.pyc
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── frnet.py
│   │   ├── lbs.py
│   │   └── resnet.py
│   ├── smirk
│   │   ├── __pycache__
│   │   │   ├── mediapipe_utils.cpython-310.pyc
│   │   │   └── smirk_encoder.cpython-310.pyc
│   │   ├── face_landmarker.task
│   │   ├── mediapipe_utils.py
│   │   ├── smirk_encoder.py
│   │   └── utils
│   │       ├── masking.py
│   │       └── utils.py
│   ├── trainer.py
│   └── utils
│       ├── __pycache__
│       │   ├── config.cpython-310.pyc
│       │   ├── renderer.cpython-310.pyc
│       │   ├── rotation_converter.cpython-310.pyc
│       │   ├── tensor_cropper.cpython-310.pyc
│       │   └── util.cpython-310.pyc
│       ├── config.py
│       ├── lossfunc.py
│       ├── rasterizer
│       │   ├── INSTALL.md
│       │   ├── __init__.py
│       │   ├── setup.py
│       │   ├── standard_rasterize_cuda.cpp
│       │   └── standard_rasterize_cuda_kernel.cu
│       ├── renderer.py
│       ├── rotation_converter.py
│       ├── tensor_cropper.py
│       ├── trainer.py
│       └── util.py
├── inference.py
├── models
│   ├── __pycache__
│   │   ├── attention.cpython-310.pyc
│   │   ├── exp_encoder.cpython-310.pyc
│   │   ├── guidance_encoder.cpython-310.pyc
│   │   ├── mgportrait_model.cpython-310.pyc
│   │   ├── motion_module.cpython-310.pyc
│   │   ├── mutual_self_attention.cpython-310.pyc
│   │   ├── resnet.cpython-310.pyc
│   │   ├── transformer_2d.cpython-310.pyc
│   │   ├── transformer_3d.cpython-310.pyc
│   │   ├── unet_2d_blocks.cpython-310.pyc
│   │   ├── unet_2d_condition.cpython-310.pyc
│   │   ├── unet_3d.cpython-310.pyc
│   │   └── unet_3d_blocks.cpython-310.pyc
│   ├── attention.py
│   ├── exp_encoder.py
│   ├── guidance_encoder.py
│   ├── mgportrait_model.py
│   ├── motion_module.py
│   ├── mutual_self_attention.py
│   ├── resnet.py
│   ├── transformer_2d.py
│   ├── transformer_3d.py
│   ├── unet_2d_blocks.py
│   ├── unet_2d_condition.py
│   ├── unet_3d.py
│   └── unet_3d_blocks.py
├── pipelines
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── context.cpython-310.pyc
│   │   ├── pipe_utils.cpython-310.pyc
│   │   └── pipeline_aggregation.cpython-310.pyc
│   ├── context.py
│   ├── pipe_utils.py
│   └── pipeline_aggregation.py
├── render_and_transfer.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-310.pyc
    │   └── video_utils.cpython-310.pyc
    ├── download.py
    ├── fs.py
    ├── postprocess.py
    ├── tb_tracker.py
    ├── util.py
    └── video_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # MagicPortrait
2 |
3 | **MagicPortrait: Temporally Consistent Face Reenactment with 3D Geometric Guidance**
4 |
5 | [Mengting Wei](),
6 | [Yante Li](),
7 | [Tuomas Varanka](),
8 | [Yan Jiang](),
9 | [Guoying Zhao]()
10 |
11 |
12 | _[arXiv](https://arxiv.org/abs/2504.21497) | [Model](https://huggingface.co/mengtingwei/MagicPortrait)_
13 |
14 | This repository contains the example inference script for the MagicPortrait-preview model.
15 |
16 | https://github.com/user-attachments/assets/9471fdd9-948a-47dd-a632-2adfc631be50
17 |
18 | ## Installation
19 |
20 | ```bash
21 | conda create -n mgp python=3.10 -y
22 | conda activate mgp
23 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
24 | pip install -r requirements.txt
25 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
26 | pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/download.html
27 | ```
28 |
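Optionally, a quick sanity check (a minimal sketch, assuming the environment above was created successfully) to confirm that the CUDA build of PyTorch and the pytorch3d wheel import correctly:

```python
# optional sanity check for the freshly created "mgp" environment
import torch
import pytorch3d

print(torch.__version__)          # expected: 2.4.1+cu121
print(torch.cuda.is_available())  # expected: True on a CUDA-capable machine
print(pytorch3d.__version__)      # confirms the prebuilt wheel matches this torch build
```
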
29 | ## Inference of the Model
30 |
31 |
32 | ### Step 1: Download pre-trained models
33 |
34 | Download our models from Hugging Face.
35 |
36 | ```bash
37 | huggingface-cli download --resume-download mengtingwei/MagicPortrait --local-dir ./pre_trained
38 | ```
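Alternatively, the same download can be scripted with `huggingface_hub` (a minimal sketch using the same repo id and target directory as above):

```python
from huggingface_hub import snapshot_download

# fetches (and resumes) the MagicPortrait weights into ./pre_trained
snapshot_download(repo_id="mengtingwei/MagicPortrait", local_dir="./pre_trained")
```
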
39 | ### Step 2: Setup necessary libraries for face motion transfer
40 | 0. Put the `third_party_files` folder downloaded in the last step under the project directory `./`.
41 | 1. Visit [DECA Github](https://github.com/yfeng95/DECA?tab=readme-ov-file) to download the pretrained `deca_model.tar`.
42 | 2. Visit [FLAME website](https://flame.is.tue.mpg.de/download.php) to download `FLAME 2020` and extract `generic_model.pkl`.
43 | 3. Visit [FLAME website](https://flame.is.tue.mpg.de/download.php) to download `FLAME texture space` and extract `FLAME_texture.npz`.
44 | 4. Visit [DECA's data page](https://github.com/yfeng95/DECA/tree/master/data) and download all files.
45 | 5. Visit [SMIRK website](https://github.com/georgeretsi/smirk) to download `SMIRK_em1.pt`.
46 | 6. Place the files in their corresponding locations as specified below (a quick check script follows the layout).
47 |
48 | ```plaintext
49 | decalib
50 | data/
51 | deca_model.tar
52 | generic_model.pkl
53 | FLAME_texture.npz
54 | fixed_displacement_256.npy
55 | head_template.obj
56 | landmark_embedding.npy
57 | mean_texture.jpg
58 | texture_data_256.npy
59 | uv_face_eye_mask.png
60 | uv_face_mask.png
61 | ...
62 | smirk/
63 | pretrained_models/
64 | SMIRK_em1.pt
65 | ...
66 | ...
67 | ```
68 |
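A short script to verify the placement before moving on (a minimal sketch; it only checks the files named above):

```python
import os

# files that Step 2 expects under decalib/data and decalib/smirk/pretrained_models
required = [
    'decalib/data/deca_model.tar',
    'decalib/data/generic_model.pkl',
    'decalib/data/FLAME_texture.npz',
    'decalib/data/fixed_displacement_256.npy',
    'decalib/data/head_template.obj',
    'decalib/data/landmark_embedding.npy',
    'decalib/smirk/pretrained_models/SMIRK_em1.pt',
]
missing = [p for p in required if not os.path.exists(p)]
print('all files in place' if not missing else f'missing: {missing}')
```
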
69 | ### Step 3: Process the identity image and driving video
70 |
71 | > As our model is designed to focus only on the face,
72 | > you should crop the face from your images or videos if they are full-body shots.
73 | > However, **if your images or videos already contain only the face and the aspect ratio is approximately 1:1,
74 | > you can simply resize them to 512 $\times$ 512 and skip the cropping steps (1 and 2) below.**
75 |
76 | 1. Crop the face from an image:
77 |
78 | ```bash
79 | python crop_process.py --sign image --img_path './assets/boy.jpeg' --save_path './assets/boy_cropped.jpg'
80 | ```
81 |
82 | 2. Crop the face sequence from the driving video.
83 |
84 | * If you have a video, first extract its frames:
85 | ```bash
86 | mkdir ./assets/driving_images
87 | ffmpeg -i ./assets/driving.mp4 ./assets/driving_images/frame_%04d.jpg
88 | ```
89 | Then crop the face from each driving frame:
90 | ```bash
91 | python crop_process.py --sign video --video_path './assets/driving_images' --video_imgs_dir './assets/driving_images_cropped'
92 | ```
93 |
94 | 3. Retrieve guidance images using the DECA and SMIRK models.
95 |
96 | ```bash
97 | python render_and_transfer.py --sor_img './assets/boy_cropped.jpg' --driving_path './assets/driving_images_cropped' --save_name example1
98 | ```
99 | The guidance will be saved in the `./transfers` directory.
100 |
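The folder name under `./transfers` matches `--save_name`, and the guidance types listed in `./configs/inference/inference.yaml` (`depth`, `normal`, `render`) determine what is generated. The exact frame naming may differ, but the output layout looks roughly like this:

```plaintext
transfers/
    example1/
        depth/     # per-frame depth guidance
        normal/    # per-frame normal guidance
        render/    # per-frame rendered guidance
```
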
101 | ### Step 4: Inference
102 |
103 | Update the model and image directories in `./configs/inference/inference.yaml` to match your own file locations.
104 |
105 | Then run:
106 | ```bash
107 | python inference.py
108 | ```
109 |
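For reference, these are the path fields in `./configs/inference/inference.yaml` that need to point at your own files (the values below are illustrative placeholders; the remaining settings can stay as shipped):

```yaml
data:
  ref_image_path: './assets/boy_cropped.jpg'       # identity image from Step 3
  guidance_data_folder: './transfers/example1'     # guidance folder from Step 3

base_model_path: '/path/to/stable-diffusion-v1-5'
vae_model_path: '/path/to/sd-vae-ft-mse'
ckpt_dir: '/path/to/ckpts/'
motion_module_path: '/path/to/ckpts/motion_module-47360.pth'
```
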
110 | ## Acknowledgement
111 |
112 | Our work is made possible by pioneering open-source
113 | 3D face reconstruction works (including [DECA](https://github.com/yfeng95/DECA?tab=readme-ov-file)
114 | and [SMIRK](https://github.com/georgeretsi/smirk)) and
115 | the high-quality talking-video dataset [CelebV-HQ](https://celebv-hq.github.io).
116 |
117 | ## Contact
118 | Open an issue here or email [mengting.wei@oulu.fi]().
--------------------------------------------------------------------------------
/assets/boy.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/boy.jpeg
--------------------------------------------------------------------------------
/assets/boy_cropped.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/boy_cropped.jpg
--------------------------------------------------------------------------------
/assets/combined_with_transitions_24.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/combined_with_transitions_24.mp4
--------------------------------------------------------------------------------
/assets/driving.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/assets/driving.mp4
--------------------------------------------------------------------------------
/configs/inference/inference.yaml:
--------------------------------------------------------------------------------
1 | exp_name: Animation
2 | width: 512
3 | height: 512
4 | data:
5 | ref_image_path: '.../assets/girl1_cropped.jpg' # reference image path
6 | guidance_data_folder: '.../transfers/example1' # corresponding motion sequence folder
7 | frame_range: [0, 100] # [Optional] specify a frame range: [min_frame_idx, max_frame_idx] to select a clip from a motion sequence
8 | seed: 42
9 |
10 | base_model_path: '.../stable-diffusion-v1-5'
11 | vae_model_path: '.../sd-vae-ft-mse'
12 |
13 |
14 | ckpt_dir: '.../ckpts/'
15 | motion_module_path: '.../ckpts/motion_module-47360.pth'
16 |
17 | num_inference_steps: 20
18 | guidance_scale: 3.5
19 | enable_zero_snr: true
20 | weight_dtype: "fp16"
21 |
22 | guidance_types:
23 | - 'depth'
24 | - 'normal'
25 | - 'render'
26 |
27 | noise_scheduler_kwargs:
28 | num_train_timesteps: 1000
29 | beta_start: 0.00085
30 | beta_end: 0.012
31 | beta_schedule: "linear"
32 | steps_offset: 1
33 | clip_sample: false
34 |
35 | unet_additional_kwargs:
36 | use_inflated_groupnorm: true
37 | unet_use_cross_frame_attention: false
38 | unet_use_temporal_attention: false
39 | use_motion_module: true
40 | motion_module_resolutions:
41 | - 1
42 | - 2
43 | - 4
44 | - 8
45 | motion_module_mid_block: true
46 | motion_module_decoder_only: false
47 | motion_module_type: Vanilla
48 | motion_module_kwargs:
49 | num_attention_heads: 8
50 | num_transformer_block: 1
51 | attention_block_types:
52 | - Temporal_Self
53 | - Temporal_Self
54 | temporal_position_encoding: true
55 | temporal_position_encoding_max_len: 32
56 | temporal_attention_dim_div: 1
57 |
58 | guidance_encoder_kwargs:
59 | guidance_embedding_channels: 320
60 | guidance_input_channels: 3
61 | block_out_channels: [16, 32, 96, 256]
62 |
63 | enable_xformers_memory_efficient_attention: false
64 |
--------------------------------------------------------------------------------
/crop_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | from insightface.app import FaceAnalysis
8 | from torchvision.utils import save_image
9 |
10 | import data_utils.datasets_faceswap as datasets_faceswap
11 |
12 | # crops are warped to 512x512 before saving, so a plain ToTensor() is sufficient
13 | pil2tensor = transforms.ToTensor()
14 |
15 |
16 | app = FaceAnalysis(name='antelopev2', root=os.path.join('./', 'third_party_files'),
17 | providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
18 | app.prepare(ctx_id=0, det_size=(640, 640))
19 |
20 |
21 | def get_bbox(dets, crop_ratio):
22 | if crop_ratio > 0:
23 | bbox = dets[0:4]
24 |         bbox_size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])  # max of box width and height
25 | bbox_x = 0.5 * (bbox[2] + bbox[0])
26 | bbox_y = 0.5 * (bbox[3] + bbox[1])
27 | x1 = bbox_x - bbox_size * crop_ratio
28 | x2 = bbox_x + bbox_size * crop_ratio
29 | y1 = bbox_y - bbox_size * crop_ratio
30 | y2 = bbox_y + bbox_size * crop_ratio
31 | bbox_pts4 = np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)
32 | else:
33 | # original box
34 | bbox = dets[0:4].reshape((2, 2))
35 | bbox_pts4 = datasets_faceswap.get_box_lm4p(bbox)
36 | return bbox_pts4
37 |
38 |
39 |
40 | def crop_one_image(args):
41 | cur_img_sor_path = args.img_path
42 | im_pil_sor = Image.open(cur_img_sor_path).convert("RGB")
43 | face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR))
44 | assert len(face_info_sor) >= 1, 'The input image must contain a face!'
45 |     if len(face_info_sor) > 1:
46 |         print('The input image contains more than one face; only the largest face will be used')
47 |     face_info_sor = \
48 |         sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
49 |     dets_sor = face_info_sor['bbox']
50 |
51 | bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75)
52 |
53 | warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor,
54 | datasets_faceswap.mean_box_lm4p_512)
55 | im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR)
56 |
57 | im_pil_sor = Image.fromarray(im_crop512_sor)
58 | im_pil_sor = pil2tensor(im_pil_sor)
59 | save_image(im_pil_sor, args.save_path)
60 |
61 |
62 | def crop_a_directory(args):
63 | os.makedirs(args.video_imgs_dir, exist_ok=True)
64 | imgs = sorted(os.listdir(args.video_path))
65 | first_img_path = os.path.join(args.video_path, imgs[0])
66 | im_pil_sor = Image.open(first_img_path).convert("RGB")
67 | face_info_sor = app.get(cv2.cvtColor(np.array(im_pil_sor), cv2.COLOR_RGB2BGR))
68 | assert len(face_info_sor) >= 1, 'The input image must contain a face!'
69 |     if len(face_info_sor) > 1:
70 |         print('The input image contains more than one face; only the largest face will be used')
71 |     face_info_sor = \
72 |         sorted(face_info_sor, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
73 | dets_sor = face_info_sor['bbox']
74 | for img in imgs:
75 | cur_img_path = os.path.join(args.video_path, img)
76 | im_pil_sor = Image.open(cur_img_path).convert("RGB")
77 | bbox_pst_sor = get_bbox(dets_sor, crop_ratio=0.75)
78 |
79 | warp_mat_crop_sor = datasets_faceswap.transformation_from_points(bbox_pst_sor,
80 | datasets_faceswap.mean_box_lm4p_512)
81 | im_crop512_sor = cv2.warpAffine(np.array(im_pil_sor), warp_mat_crop_sor, (512, 512), flags=cv2.INTER_LINEAR)
82 |
83 | im_pil_sor = Image.fromarray(im_crop512_sor)
84 | im_pil_sor = pil2tensor(im_pil_sor)
85 | save_image(im_pil_sor, os.path.join(args.video_imgs_dir, img))
86 |
87 |
88 | if __name__ == '__main__':
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument(
91 | "--sign",
92 | type=str,
93 | default='image',
94 | required=False
95 | )
96 | # ********************************* identity image ***************************************
97 | parser.add_argument(
98 | "--img_path",
99 | type=str,
100 | default='/home/mengting/projects/process_scripts/test_images/boy2.jpeg',
101 | required=False
102 | )
103 | parser.add_argument(
104 | "--save_path",
105 | type=str,
106 | default='/home/mengting/projects/process_scripts/test_images/boy2_cropped.jpg',
107 | required=False
108 | )
109 | # ********************************** driving video **************************************
110 | parser.add_argument(
111 | "--video_path",
112 | type=str,
113 | default='/home/mengting/projects/process_scripts/test_images/target_images1',
114 | required=False
115 | )
116 | parser.add_argument(
117 | "--video_imgs_dir",
118 | type=str,
119 | default='/home/mengting/projects/process_scripts/test_images/target_images1_cropped',
120 | required=False
121 | )
122 | args = parser.parse_args()
123 | if args.sign == 'image':
124 | crop_one_image(args)
125 | elif args.sign == 'video':
126 | crop_a_directory(args)
127 | else:
128 |         raise ValueError(f'invalid --sign value: {args.sign}')
--------------------------------------------------------------------------------
/data_utils/__pycache__/datasets_faceswap.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/data_utils/__pycache__/datasets_faceswap.cpython-310.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/transfer_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/data_utils/__pycache__/transfer_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/data_utils/datasets_faceswap.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import cv2
4 | from PIL import ImageFile
5 |
6 |
7 | ImageFile.LOAD_TRUNCATED_IMAGES = True
8 |
9 | mean_face_lm5p_256 = np.array([
10 | [(30.2946+8)*2+16, 51.6963*2], # left eye pupil
11 | [(65.5318+8)*2+16, 51.5014*2], # right eye pupil
12 | [(48.0252+8)*2+16, 71.7366*2], # nose tip
13 | [(33.5493+8)*2+16, 92.3655*2], # left mouth corner
14 | [(62.7299+8)*2+16, 92.2041*2], # right mouth corner
15 | ], dtype=np.float32)
16 |
17 |
18 |
19 | mean_box_lm4p_512 = np.array([
20 | [80, 80],
21 | [80, 432],
22 | [432, 432],
23 | [432, 80],
24 | ], dtype=np.float32)
25 |
26 |
27 |
28 | def get_box_lm4p(pts):
29 | x1 = np.min(pts[:,0])
30 | x2 = np.max(pts[:,0])
31 | y1 = np.min(pts[:,1])
32 | y2 = np.max(pts[:,1])
33 |
34 | x_center = (x1+x2)*0.5
35 | y_center = (y1+y2)*0.5
36 | box_size = max(x2-x1, y2-y1)
37 |
38 | x1 = x_center-0.5*box_size
39 | x2 = x_center+0.5*box_size
40 | y1 = y_center-0.5*box_size
41 | y2 = y_center+0.5*box_size
42 |
43 | return np.array([[x1, y1], [x1, y2], [x2, y2], [x2, y1]], dtype=np.float32)
44 |
45 |
46 | def get_affine_transform(target_face_lm5p, mean_lm5p):
47 | mat_warp = np.zeros((2,3))
48 | A = np.zeros((4,4))
49 | B = np.zeros((4))
50 | for i in range(5):
51 | #sa[0][0] += a[i].x*a[i].x + a[i].y*a[i].y;
52 | A[0][0] += target_face_lm5p[i][0] * target_face_lm5p[i][0] + target_face_lm5p[i][1] * target_face_lm5p[i][1]
53 | #sa[0][2] += a[i].x;
54 | A[0][2] += target_face_lm5p[i][0]
55 | #sa[0][3] += a[i].y;
56 | A[0][3] += target_face_lm5p[i][1]
57 |
58 | #sb[0] += a[i].x*b[i].x + a[i].y*b[i].y;
59 | B[0] += target_face_lm5p[i][0] * mean_lm5p[i][0] + target_face_lm5p[i][1] * mean_lm5p[i][1]
60 | #sb[1] += a[i].x*b[i].y - a[i].y*b[i].x;
61 | B[1] += target_face_lm5p[i][0] * mean_lm5p[i][1] - target_face_lm5p[i][1] * mean_lm5p[i][0]
62 | #sb[2] += b[i].x;
63 | B[2] += mean_lm5p[i][0]
64 | #sb[3] += b[i].y;
65 | B[3] += mean_lm5p[i][1]
66 |
67 | #sa[1][1] = sa[0][0];
68 | A[1][1] = A[0][0]
69 | #sa[2][1] = sa[1][2] = -sa[0][3];
70 | A[2][1] = A[1][2] = -A[0][3]
71 | #sa[3][1] = sa[1][3] = sa[2][0] = sa[0][2];
72 | A[3][1] = A[1][3] = A[2][0] = A[0][2]
73 | #sa[2][2] = sa[3][3] = count;
74 | A[2][2] = A[3][3] = 5
75 | #sa[3][0] = sa[0][3];
76 | A[3][0] = A[0][3]
77 |
78 | _, mat23 = cv2.solve(A, B, flags=cv2.DECOMP_SVD)
79 | mat_warp[0][0] = mat23[0]
80 | mat_warp[1][1] = mat23[0]
81 | mat_warp[0][1] = -mat23[1]
82 | mat_warp[1][0] = mat23[1]
83 | mat_warp[0][2] = mat23[2]
84 | mat_warp[1][2] = mat23[3]
85 |
86 | return mat_warp
87 |
88 |
89 |
90 |
91 | def transformation_from_points(points1, points2):
92 | points1 = np.float64(np.matrix([[point[0], point[1]] for point in points1]))
93 | points2 = np.float64(np.matrix([[point[0], point[1]] for point in points2]))
94 |
95 | points1 = points1.astype(np.float64)
96 | points2 = points2.astype(np.float64)
97 | c1 = np.mean(points1, axis=0)
98 | c2 = np.mean(points2, axis=0)
99 | points1 -= c1
100 | points2 -= c2
101 | s1 = np.std(points1)
102 | s2 = np.std(points2)
103 | points1 /= s1
104 | points2 /= s2
105 | #points2 = np.array(points2)
106 | #write_pts('pt2.txt', points2)
107 | U, S, Vt = np.linalg.svd(points1.T * points2)
108 | R = (U * Vt).T
109 | return np.array(np.vstack([np.hstack(((s2 / s1) * R,c2.T - (s2 / s1) * R * c1.T)),np.matrix([0., 0., 1.])])[:2])
110 |
111 |
--------------------------------------------------------------------------------
/data_utils/transfer_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import cv2
8 | import scipy
9 | from skimage.io import imread, imsave
10 | from skimage.transform import estimate_transform, warp, resize, rescale
11 | from glob import glob
12 | import scipy.io
13 | from decalib.datasets import datasets
14 | from torchvision.utils import save_image
15 | from decalib.datasets import detectors
16 | import shutil
17 |
18 |
19 | face_detector = detectors.FAN()
20 | scale = 1.3
21 | resolution_inp = 224
22 |
23 |
24 |
25 | def bbox2point(left, right, top, bottom, type='bbox'):
26 |
27 | if type =='kpt68':
28 | old_size = (right - left + bottom - top) / 2 * 1.1
29 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
30 | elif type =='bbox':
31 | old_size = (right - left + bottom - top ) /2
32 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size *0.12])
33 | else:
34 | raise NotImplementedError
35 | return old_size, center
36 |
37 |
38 |
39 |
40 |
41 | def get_image_dict(img_path, size, iscrop):
42 | img_name = img_path.split('/')[-1]
43 | im = imread(img_path)
44 | if size is not None: # size = 256
45 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8)
46 | # (256, 256, 3)
47 | image = np.array(im)
48 | if len(image.shape) == 2:
49 |         image = np.repeat(image[:, :, None], 3, axis=2)  # grayscale -> 3 channels
50 | if len(image.shape) == 3 and image.shape[2] > 3:
51 | image = image[:, :, :3]
52 |
53 | h, w, _ = image.shape
54 | if iscrop: # true
55 | # provide kpt as txt file, or mat file (for AFLW2000)
56 | kpt_matpath = os.path.splitext(img_path)[0] + '.mat'
57 | kpt_txtpath = os.path.splitext(img_path)[0] + '.txt'
58 | if os.path.exists(kpt_matpath):
59 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T
60 | left = np.min(kpt[:, 0])
61 | right = np.max(kpt[:, 0])
62 | top = np.min(kpt[:, 1])
63 | bottom = np.max(kpt[:, 1])
64 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68')
65 | elif os.path.exists(kpt_txtpath):
66 | kpt = np.loadtxt(kpt_txtpath)
67 | left = np.min(kpt[:, 0])
68 | right = np.max(kpt[:, 0])
69 | top = np.min(kpt[:, 1])
70 | bottom = np.max(kpt[:, 1])
71 | old_size, center = bbox2point(left, right, top, bottom, type='kpt68')
72 | else:
73 | bbox, bbox_type = face_detector.run(image)
74 | if len(bbox) < 4:
75 | print('no face detected! run original image')
76 | left = 0
77 | right = h - 1
78 | top = 0
79 | bottom = w - 1
80 | else:
81 | left = bbox[0]
82 | right = bbox[2]
83 | top = bbox[1]
84 | bottom = bbox[3]
85 | old_size, center = bbox2point(left, right, top, bottom, type=bbox_type)
86 | size = int(old_size * scale)
87 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
88 | [center[0] + size / 2, center[1] - size / 2]])
89 | else:
90 | src_pts = np.array([[0, 0], [0, h - 1], [w - 1, 0]])
91 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]])
92 | DST_PTS = np.array([[0, 0], [0, resolution_inp - 1], [resolution_inp - 1, 0]])
93 | tform = estimate_transform('similarity', src_pts, DST_PTS)
94 |
95 | image = image / 255.
96 |
97 | dst_image = warp(image, tform.inverse, output_shape=(resolution_inp, resolution_inp))
98 | dst_image = dst_image.transpose(2, 0, 1)
99 | return {'image': torch.tensor(dst_image).float(),
100 | 'imagename': img_name,
101 | 'tform': torch.tensor(tform.params).float(),
102 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(),
103 | }
104 |
105 |
106 |
107 | def check_face(img_path, size):
108 | im = imread(img_path)
109 | if size is not None: # size = 256
110 | im = (resize(im, (size, size), anti_aliasing=True) * 255.).astype(np.uint8)
111 | # (256, 256, 3)
112 | image = np.array(im)
113 | bbox, bbox_type = face_detector.run(image)
114 | if len(bbox) < 4:
115 | return False
116 | return True
117 |
118 |
--------------------------------------------------------------------------------
/decalib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__init__.py
--------------------------------------------------------------------------------
/decalib/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/__pycache__/deca_with_smirk.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/__pycache__/deca_with_smirk.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/datasets/__pycache__/datasets.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/datasets/__pycache__/datasets.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/datasets/__pycache__/detectors.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/datasets/__pycache__/detectors.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/datasets/aflw2000.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 | import scipy.io
12 |
13 | class AFLW2000(Dataset):
14 | def __init__(self, testpath='/ps/scratch/yfeng/Data/AFLW2000/GT', crop_size=224):
15 | '''
16 | data class for loading AFLW2000 dataset
17 | make sure each image has corresponding mat file, which provides cropping infromation
18 | '''
19 | if os.path.isdir(testpath):
20 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png')
21 | elif isinstance(testpath, list):
22 | self.imagepath_list = testpath
23 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png']):
24 | self.imagepath_list = [testpath]
25 | else:
26 | print('please check the input path')
27 | exit()
28 | print('total {} images'.format(len(self.imagepath_list)))
29 | self.imagepath_list = sorted(self.imagepath_list)
30 | self.crop_size = crop_size
31 | self.scale = 1.6
32 | self.resolution_inp = crop_size
33 |
34 | def __len__(self):
35 | return len(self.imagepath_list)
36 |
37 | def __getitem__(self, index):
38 | imagepath = self.imagepath_list[index]
39 | imagename = imagepath.split('/')[-1].split('.')[0]
40 | image = imread(imagepath)[:,:,:3]
41 | kpt = scipy.io.loadmat(imagepath.replace('jpg', 'mat'))['pt3d_68'].T
42 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
43 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
44 |
45 | h, w, _ = image.shape
46 | old_size = (right - left + bottom - top)/2
47 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
48 | size = int(old_size*self.scale)
49 |
50 | # crop image
51 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
52 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
53 | tform = estimate_transform('similarity', src_pts, DST_PTS)
54 |
55 | image = image/255.
56 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
57 | dst_image = dst_image.transpose(2,0,1)
58 | return {'image': torch.tensor(dst_image).float(),
59 | 'imagename': imagename,
60 | # 'tform': tform,
61 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
62 | }
--------------------------------------------------------------------------------
/decalib/datasets/build_datasets.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | from torch.utils.data import Dataset, ConcatDataset
4 | import torchvision.transforms as transforms
5 | import numpy as np
6 | import cv2
7 | import scipy
8 | from skimage.io import imread, imsave
9 | from skimage.transform import estimate_transform, warp, resize, rescale
10 | from glob import glob
11 |
12 | from .vggface import VGGFace2Dataset, VGGFace2HQDataset
13 | from .ethnicity import EthnicityDataset
14 | from .aflw2000 import AFLW2000
15 | from .now import NoWDataset
16 | from .vox import VoxelDataset
17 |
18 | def build_train(config, is_train=True):
19 | data_list = []
20 | if 'vox2' in config.training_data:
21 | data_list.append(VoxelDataset(dataname='vox2', K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
22 | if 'vggface2' in config.training_data:
23 | data_list.append(VGGFace2Dataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
24 | if 'vggface2hq' in config.training_data:
25 | data_list.append(VGGFace2HQDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
26 | if 'ethnicity' in config.training_data:
27 | data_list.append(EthnicityDataset(K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
28 | if 'coco' in config.training_data:
29 | data_list.append(COCODataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale))
30 | if 'celebahq' in config.training_data:
31 | data_list.append(CelebAHQDataset(image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale))
32 | dataset = ConcatDataset(data_list)
33 |
34 | return dataset
35 |
36 | def build_val(config, is_train=True):
37 | data_list = []
38 | if 'vggface2' in config.eval_data:
39 | data_list.append(VGGFace2Dataset(isEval=True, K=config.K, image_size=config.image_size, scale=[config.scale_min, config.scale_max], trans_scale=config.trans_scale, isSingle=config.isSingle))
40 | if 'now' in config.eval_data:
41 | data_list.append(NoWDataset())
42 | if 'aflw2000' in config.eval_data:
43 | data_list.append(AFLW2000())
44 | dataset = ConcatDataset(data_list)
45 |
46 | return dataset
47 |
--------------------------------------------------------------------------------
/decalib/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import os, sys
17 | import torch
18 | from torch.utils.data import Dataset, DataLoader
19 | import torchvision.transforms as transforms
20 | import numpy as np
21 | import cv2
22 | import scipy
23 | from skimage.io import imread, imsave
24 | from skimage.transform import estimate_transform, warp, resize, rescale
25 | from glob import glob
26 | import scipy.io
27 | from decalib.datasets import datasets
28 | from torchvision.utils import save_image
29 |
30 | from . import detectors
31 |
32 | def video2sequence(video_path, sample_step=10):
33 | videofolder = os.path.splitext(video_path)[0]
34 | os.makedirs(videofolder, exist_ok=True)
35 | video_name = os.path.splitext(os.path.split(video_path)[-1])[0]
36 | vidcap = cv2.VideoCapture(video_path)
37 | success,image = vidcap.read()
38 | count = 0
39 | imagepath_list = []
40 | while success:
41 | # if count%sample_step == 0:
42 | imagepath = os.path.join(videofolder, f'{video_name}_frame{count:04d}.jpg')
43 | cv2.imwrite(imagepath, image) # save frame as JPEG file
44 | success,image = vidcap.read()
45 | count += 1
46 | imagepath_list.append(imagepath)
47 | print('video frames are stored in {}'.format(videofolder))
48 | return imagepath_list
49 |
50 | class TestData(Dataset):
51 |     # called with iscrop=True, size=256, sort=True
52 | def __init__(self, testpath, iscrop=True, crop_size=224, scale=1.25, face_detector='fan',
53 | sample_step=10, size=256, sort=False):
54 |
55 | if isinstance(testpath, list):
56 | self.imagepath_list = testpath
57 | elif os.path.isdir(testpath):
58 | self.imagepath_list = glob(testpath + '/*.jpg') + glob(testpath + '/*.png') + glob(testpath + '/*.bmp')
59 | elif os.path.isfile(testpath) and (testpath[-3:] in ['jpg', 'png', 'bmp']):
60 | self.imagepath_list = [testpath]
61 | elif os.path.isfile(testpath) and (testpath[-3:] in ['mp4', 'csv', 'vid', 'ebm']):
62 | self.imagepath_list = video2sequence(testpath, sample_step)
63 |
64 | if sort:
65 | self.imagepath_list = sorted(self.imagepath_list)
66 | self.crop_size = crop_size
67 | self.scale = scale
68 | self.iscrop = iscrop
69 | self.resolution_inp = crop_size
70 | self.size = size
71 |         # use the face-alignment (FAN) landmark detector
72 | if face_detector == 'fan':
73 | self.face_detector = detectors.FAN()
74 | else:
75 | print(f'please check the detector: {face_detector}')
76 | exit()
77 |
78 | def __len__(self):
79 | return len(self.imagepath_list)
80 |
81 | def bbox2point(self, left, right, top, bottom, type='bbox'):
82 |
83 | if type=='kpt68':
84 | old_size = (right - left + bottom - top) / 2 * 1.1
85 |             # center of the face bounding box
86 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
87 | elif type=='bbox':
88 | old_size = (right - left + bottom - top)/2
89 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.12])
90 | else:
91 | raise NotImplementedError
92 | return old_size, center
93 |
94 | def get_image(self, image):
95 | h, w, _ = image.shape
96 | bbox, bbox_type = self.face_detector.run(image)
97 | if len(bbox) < 4:
98 | print('no face detected! run original image')
99 | left = 0; right = h-1; top=0; bottom=w-1
100 | else:
101 | left = bbox[0]; right=bbox[2]
102 | top = bbox[1]; bottom=bbox[3]
103 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type)
104 | size = int(old_size*self.scale)
105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
106 |
107 | DST_PTS = np.array([[0,0], [0,self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
108 | tform = estimate_transform('similarity', src_pts, DST_PTS)
109 |
110 | image = image / 255.
111 |
112 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
113 | dst_image = dst_image.transpose(2,0,1)
114 | return {'image': torch.tensor(dst_image).float(),
115 | 'tform': torch.tensor(tform.params).float(),
116 | 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
117 | }
118 |
119 |
120 | def __getitem__(self, index):
121 |
122 | imagepath = self.imagepath_list[index]
123 | imagename = os.path.splitext(os.path.split(imagepath)[-1])[0]
124 | im = imread(imagepath)
125 |
126 | if self.size is not None: # size = 256
127 | im = (resize(im, (self.size, self.size), anti_aliasing=True) * 255.).astype(np.uint8)
128 |
129 | # (256, 256, 3)
130 | image = np.array(im)
131 |
132 | if len(image.shape) == 2:
133 |             image = np.repeat(image[:, :, None], 3, axis=2)  # grayscale -> 3 channels
134 | if len(image.shape) == 3 and image.shape[2] > 3:
135 | image = image[:, :, :3]
136 |
137 | h, w, _ = image.shape
138 | if self.iscrop: # true
139 | # provide kpt as txt file, or mat file (for AFLW2000)
140 |             # check for an existing landmark file; if none is found, run the face detector
141 | kpt_matpath = os.path.splitext(imagepath)[0]+'.mat'
142 | kpt_txtpath = os.path.splitext(imagepath)[0]+'.txt'
143 | if os.path.exists(kpt_matpath):
144 | kpt = scipy.io.loadmat(kpt_matpath)['pt3d_68'].T
145 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
146 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
147 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68')
148 | elif os.path.exists(kpt_txtpath):
149 | kpt = np.loadtxt(kpt_txtpath)
150 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
151 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
152 | old_size, center = self.bbox2point(left, right, top, bottom, type='kpt68')
153 | else:
154 | bbox, bbox_type = self.face_detector.run(image)
155 | if len(bbox) < 4:
156 | print('no face detected! run original image')
157 | left = 0; right = h-1; top=0; bottom=w-1
158 | else:
159 | left = bbox[0]; right=bbox[2]
160 | top = bbox[1]; bottom=bbox[3]
161 | old_size, center = self.bbox2point(left, right, top, bottom, type=bbox_type)
162 | size = int(old_size * self.scale)
163 | src_pts = np.array([[center[0] - size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
164 | else:
165 | src_pts = np.array([[0, 0], [0, h-1], [w-1, 0]])
166 | # DST_PTS = np.array([[0, 0], [0, h-1], [w-1, 0]])
167 |         # self.resolution_inp = 224, the target crop size
168 | DST_PTS = np.array([[0, 0], [0, self.resolution_inp - 1], [self.resolution_inp - 1, 0]])
169 |         # estimate the similarity transform that maps the source points to the target crop
170 | tform = estimate_transform('similarity', src_pts, DST_PTS)
171 |
172 |         image = image / 255.  # scale to [0, 1]
173 |
174 | dst_image = warp(image, tform.inverse, output_shape=(self.resolution_inp, self.resolution_inp))
175 | dst_image = dst_image.transpose(2,0,1)
176 |
177 |         return {'image': torch.tensor(dst_image).float(),  # only the face region is kept
178 | 'imagename': imagename,
179 | 'tform': torch.tensor(tform.params).float(),
180 | 'original_image': torch.tensor(image.transpose(2, 0, 1)).float(),
181 | }
182 |
183 |
184 |
185 |
186 |
187 | if __name__ == '__main__':
188 |     # minimal usage example (the path is an illustrative placeholder)
189 |     source = './assets/driving_images_cropped'
190 |     testdata_source = TestData(source, iscrop=True, size=512, sort=True)
--------------------------------------------------------------------------------
/decalib/datasets/detectors.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import numpy as np
17 | import torch
18 |
19 | class FAN(object):
20 | def __init__(self):
21 | import face_alignment
22 | self.model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
23 |
24 | def run(self, image):
25 | '''
26 | image: 0-255, uint8, rgb, [h, w, 3]
27 | return: detected box list
28 | '''
29 | out = self.model.get_landmarks(image)
30 | if out is None:
31 | return [0], 'kpt68'
32 | else:
33 | kpt = out[0].squeeze()
34 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
35 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
36 | bbox = [left, top, right, bottom]
37 | return bbox, 'kpt68'
38 |
39 | class MTCNN(object):
40 | def __init__(self, device = 'cpu'):
41 | '''
42 | https://github.com/timesler/facenet-pytorch/blob/master/examples/infer.ipynb
43 | '''
44 | from facenet_pytorch import MTCNN as mtcnn
45 | self.device = device
46 | self.model = mtcnn(keep_all=True)
47 | def run(self, input):
48 | '''
49 | image: 0-255, uint8, rgb, [h, w, 3]
50 | return: detected box
51 | '''
52 | out = self.model.detect(input[None,...])
53 | if out[0][0] is None:
54 |             return [0], 'bbox'
55 | else:
56 | bbox = out[0][0].squeeze()
57 | return bbox, 'bbox'
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/decalib/datasets/ethnicity.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class EthnicityDataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
14 | '''
15 | K must be less than 6
16 | '''
17 | self.K = K
18 | self.image_size = image_size
19 | self.imagefolder = '/ps/scratch/face2d3d/train'
20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/'
21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/'
22 | # hq:
23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_and_race_per_7000_african_asian_2d_train_list_max_normal_100_ring_5_1_serial.npy'
25 | self.data_lines = np.load(datafile).astype('str')
26 |
27 | self.isTemporal = isTemporal
28 | self.scale = scale #[scale_min, scale_max]
29 | self.trans_scale = trans_scale #[dx, dy]
30 | self.isSingle = isSingle
31 | if isSingle:
32 | self.K = 1
33 |
34 | def __len__(self):
35 | return len(self.data_lines)
36 |
37 | def __getitem__(self, idx):
38 | images_list = []; kpt_list = []; mask_list = []
39 | for i in range(self.K):
40 | name = self.data_lines[idx, i]
41 | if name[0]=='n':
42 | self.imagefolder = '/ps/scratch/face2d3d/train/'
43 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7/'
44 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch/'
45 | elif name[0]=='A':
46 | self.imagefolder = '/ps/scratch/face2d3d/race_per_7000/'
47 | self.kptfolder = '/ps/scratch/face2d3d/race_per_7000_annotated_torch7_new/'
48 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/race7000_seg/test_crop_size_400_batch/'
49 |
50 | image_path = os.path.join(self.imagefolder, name + '.jpg')
51 | seg_path = os.path.join(self.segfolder, name + '.npy')
52 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
53 |
54 | image = imread(image_path)/255.
55 | kpt = np.load(kpt_path)[:,:2]
56 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
57 |
58 | ### crop information
59 | tform = self.crop(image, kpt)
60 | ## crop
61 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
62 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
63 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
64 |
65 | # normalized kpt
66 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
67 |
68 | images_list.append(cropped_image.transpose(2,0,1))
69 | kpt_list.append(cropped_kpt)
70 | mask_list.append(cropped_mask)
71 |
72 | ###
73 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3
74 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3
75 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3
76 |
77 | if self.isSingle:
78 | images_array = images_array.squeeze()
79 | kpt_array = kpt_array.squeeze()
80 | mask_array = mask_array.squeeze()
81 |
82 | data_dict = {
83 | 'image': images_array,
84 | 'landmark': kpt_array,
85 | 'mask': mask_array
86 | }
87 |
88 | return data_dict
89 |
90 | def crop(self, image, kpt):
91 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
92 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
93 |
94 | h, w, _ = image.shape
95 | old_size = (right - left + bottom - top)/2
96 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
97 | # translate center
98 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
99 | center = center + trans_scale*old_size # 0.5
100 |
101 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
102 | size = int(old_size*scale)
103 |
104 | # crop image
105 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
106 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
107 | tform = estimate_transform('similarity', src_pts, DST_PTS)
108 |
109 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
110 | # # change kpt accordingly
111 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
112 | return tform
113 |
114 | def load_mask(self, maskpath, h, w):
115 | # print(maskpath)
116 | if os.path.isfile(maskpath):
117 | vis_parsing_anno = np.load(maskpath)
118 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
119 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
120 | mask = np.zeros_like(vis_parsing_anno)
121 | # for i in range(1, 16):
122 | mask[vis_parsing_anno>0.5] = 1.
123 | else:
124 | mask = np.ones((h, w))
125 | return mask
126 |
127 |
--------------------------------------------------------------------------------
/decalib/datasets/now.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class NoWDataset(Dataset):
13 | def __init__(self, ring_elements=6, crop_size=224, scale=1.6):
14 | folder = '/ps/scratch/yfeng/other-github/now_evaluation/data/NoW_Dataset'
15 | self.data_path = os.path.join(folder, 'imagepathsvalidation.txt')
16 | with open(self.data_path) as f:
17 | self.data_lines = f.readlines()
18 |
19 | self.imagefolder = os.path.join(folder, 'final_release_version', 'iphone_pictures')
20 | self.bbxfolder = os.path.join(folder, 'final_release_version', 'detected_face')
21 |
22 | # self.data_path = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/test_image_paths_ring_6_elements.npy'
23 | # self.imagepath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/iphone_pictures/'
24 | # self.bbxpath = '/ps/scratch/face2d3d/ringnetpp/eccv/test_data/evaluation/NoW_Dataset/final_release_version/detected_face/'
25 | self.crop_size = crop_size
26 | self.scale = scale
27 |
28 | def __len__(self):
29 | return len(self.data_lines)
30 |
31 | def __getitem__(self, index):
32 | imagepath = os.path.join(self.imagefolder, self.data_lines[index].strip()) #+ '.jpg'
33 | bbx_path = os.path.join(self.bbxfolder, self.data_lines[index].strip().replace('.jpg', '.npy'))
34 | bbx_data = np.load(bbx_path, allow_pickle=True, encoding='latin1').item()
35 | # box = np.array([[bbx_data['left'], bbx_data['top']], [bbx_data['right'], bbx_data['bottom']]]).astype('float32')
36 | left = bbx_data['left']; right = bbx_data['right']
37 | top = bbx_data['top']; bottom = bbx_data['bottom']
38 |
39 | imagename = imagepath.split('/')[-1].split('.')[0]
40 | image = imread(imagepath)[:,:,:3]
41 |
42 | h, w, _ = image.shape
43 | old_size = (right - left + bottom - top)/2
44 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
45 | size = int(old_size*self.scale)
46 |
47 | # crop image
48 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
49 | DST_PTS = np.array([[0,0], [0,self.crop_size - 1], [self.crop_size - 1, 0]])
50 | tform = estimate_transform('similarity', src_pts, DST_PTS)
51 |
52 | image = image/255.
53 | dst_image = warp(image, tform.inverse, output_shape=(self.crop_size, self.crop_size))
54 | dst_image = dst_image.transpose(2,0,1)
55 | return {'image': torch.tensor(dst_image).float(),
56 | 'imagename': self.data_lines[index].strip().replace('.jpg', ''),
57 | # 'tform': tform,
58 | # 'original_image': torch.tensor(image.transpose(2,0,1)).float(),
59 | }
--------------------------------------------------------------------------------
/decalib/datasets/vggface.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class VGGFace2Dataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
14 | '''
15 | K must be less than 6
16 | '''
17 | self.K = K
18 | self.image_size = image_size
19 | self.imagefolder = '/ps/scratch/face2d3d/train'
20 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7'
21 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch'
22 | # hq:
23 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
24 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_train_list_max_normal_100_ring_5_1_serial.npy'
25 | if isEval:
26 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_val_list_max_normal_100_ring_5_1_serial.npy'
27 | self.data_lines = np.load(datafile).astype('str')
28 |
29 | self.isTemporal = isTemporal
30 | self.scale = scale #[scale_min, scale_max]
31 | self.trans_scale = trans_scale #[dx, dy]
32 | self.isSingle = isSingle
33 | if isSingle:
34 | self.K = 1
35 |
36 | def __len__(self):
37 | return len(self.data_lines)
38 |
39 | def __getitem__(self, idx):
40 | images_list = []; kpt_list = []; mask_list = []
41 |
42 | random_ind = np.random.permutation(5)[:self.K]
43 | for i in random_ind:
44 | name = self.data_lines[idx, i]
45 | image_path = os.path.join(self.imagefolder, name + '.jpg')
46 | seg_path = os.path.join(self.segfolder, name + '.npy')
47 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
48 |
49 | image = imread(image_path)/255.
50 | kpt = np.load(kpt_path)[:,:2]
51 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
52 |
53 | ### crop information
54 | tform = self.crop(image, kpt)
55 | ## crop
56 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
57 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
58 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
59 |
60 | # normalized kpt
61 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
62 |
63 | images_list.append(cropped_image.transpose(2,0,1))
64 | kpt_list.append(cropped_kpt)
65 | mask_list.append(cropped_mask)
66 |
67 | ###
68 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3
69 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3
70 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) #K,224,224,3
71 |
72 | if self.isSingle:
73 | images_array = images_array.squeeze()
74 | kpt_array = kpt_array.squeeze()
75 | mask_array = mask_array.squeeze()
76 |
77 | data_dict = {
78 | 'image': images_array,
79 | 'landmark': kpt_array,
80 | 'mask': mask_array
81 | }
82 |
83 | return data_dict
84 |
85 | def crop(self, image, kpt):
86 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
87 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
88 |
89 | h, w, _ = image.shape
90 | old_size = (right - left + bottom - top)/2
91 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
92 | # translate center
93 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
94 | center = center + trans_scale*old_size # 0.5
95 |
96 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
97 | size = int(old_size*scale)
98 |
99 | # crop image
100 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
101 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
102 | tform = estimate_transform('similarity', src_pts, DST_PTS)
103 |
104 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
105 | # # change kpt accordingly
106 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
107 | return tform
108 |
109 | def load_mask(self, maskpath, h, w):
110 | # print(maskpath)
111 | if os.path.isfile(maskpath):
112 | vis_parsing_anno = np.load(maskpath)
113 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
114 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
115 | mask = np.zeros_like(vis_parsing_anno)
116 | # for i in range(1, 16):
117 | mask[vis_parsing_anno>0.5] = 1.
118 | else:
119 | mask = np.ones((h, w))
120 | return mask
121 |
122 |
123 |
124 | class VGGFace2HQDataset(Dataset):
125 | def __init__(self, K, image_size, scale, trans_scale = 0, isTemporal=False, isEval=False, isSingle=False):
126 | '''
127 | K must be less than 6
128 | '''
129 | self.K = K
130 | self.image_size = image_size
131 | self.imagefolder = '/ps/scratch/face2d3d/train'
132 | self.kptfolder = '/ps/scratch/face2d3d/train_annotated_torch7'
133 | self.segfolder = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_seg/test_crop_size_400_batch'
134 | # hq:
135 | # datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
136 | datafile = '/ps/scratch/face2d3d/texture_in_the_wild_code/VGGFace2_cleaning_codes/ringnetpp_training_lists/second_cleaning/vggface2_bbx_size_bigger_than_400_train_list_max_normal_100_ring_5_1_serial.npy'
137 | self.data_lines = np.load(datafile).astype('str')
138 |
139 | self.isTemporal = isTemporal
140 | self.scale = scale #[scale_min, scale_max]
141 | self.trans_scale = trans_scale #[dx, dy]
142 | self.isSingle = isSingle
143 | if isSingle:
144 | self.K = 1
145 |
146 | def __len__(self):
147 | return len(self.data_lines)
148 |
149 | def __getitem__(self, idx):
150 | images_list = []; kpt_list = []; mask_list = []
151 |
152 | for i in range(self.K):
153 | name = self.data_lines[idx, i]
154 | image_path = os.path.join(self.imagefolder, name + '.jpg')
155 | seg_path = os.path.join(self.segfolder, name + '.npy')
156 | kpt_path = os.path.join(self.kptfolder, name + '.npy')
157 |
158 | image = imread(image_path)/255.
159 | kpt = np.load(kpt_path)[:,:2]
160 | mask = self.load_mask(seg_path, image.shape[0], image.shape[1])
161 |
162 | ### crop information
163 | tform = self.crop(image, kpt)
164 | ## crop
165 | cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
166 | cropped_mask = warp(mask, tform.inverse, output_shape=(self.image_size, self.image_size))
167 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
168 |
169 | # normalized kpt
170 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
171 |
172 | images_list.append(cropped_image.transpose(2,0,1))
173 | kpt_list.append(cropped_kpt)
174 | mask_list.append(cropped_mask)
175 |
176 | ###
177 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) # K,3,224,224
178 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) # K,68,3
179 | mask_array = torch.from_numpy(np.array(mask_list)).type(dtype = torch.float32) # K,224,224
180 |
181 | if self.isSingle:
182 | images_array = images_array.squeeze()
183 | kpt_array = kpt_array.squeeze()
184 | mask_array = mask_array.squeeze()
185 |
186 | data_dict = {
187 | 'image': images_array,
188 | 'landmark': kpt_array,
189 | 'mask': mask_array
190 | }
191 |
192 | return data_dict
193 |
194 | def crop(self, image, kpt):
195 | left = np.min(kpt[:,0]); right = np.max(kpt[:,0]);
196 | top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])
197 |
198 | h, w, _ = image.shape
199 | old_size = (right - left + bottom - top)/2
200 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])#+ old_size*0.1])
201 | # translate center
202 | trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
203 | center = center + trans_scale*old_size # 0.5
204 |
205 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
206 | size = int(old_size*scale)
207 |
208 | # crop image
209 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
210 | DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
211 | tform = estimate_transform('similarity', src_pts, DST_PTS)
212 |
213 | # cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
214 | # # change kpt accordingly
215 | # cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T # np.linalg.inv(tform.params)
216 | return tform
217 |
218 | def load_mask(self, maskpath, h, w):
219 | # print(maskpath)
220 | if os.path.isfile(maskpath):
221 | vis_parsing_anno = np.load(maskpath)
222 | # atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
223 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
224 | mask = np.zeros_like(vis_parsing_anno)
225 | # for i in range(1, 16):
226 | mask[vis_parsing_anno>0.5] = 1.
227 | else:
228 | mask = np.ones((h, w))
229 | return mask
--------------------------------------------------------------------------------
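Both dataset classes above crop with the same recipe: take the landmark bounding box, jitter and scale it, estimate a similarity transform onto an `image_size` square, then warp the image, the mask, and the landmarks with that transform. The sketch below is not part of the repo; it reproduces that recipe on placeholder data, with a fixed scale in place of the random train-time jitter.

```python
import numpy as np
from skimage.transform import estimate_transform, warp

image_size = 224
image = np.random.rand(400, 400, 3)           # stand-in for a loaded face image in [0, 1]
kpt = np.random.rand(68, 2) * 200 + 100       # stand-in 68-point 2D landmarks

left, right = kpt[:, 0].min(), kpt[:, 0].max()
top, bottom = kpt[:, 1].min(), kpt[:, 1].max()
old_size = (right - left + bottom - top) / 2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
size = int(old_size * 1.25)                   # fixed scale; training samples it from [scale_min, scale_max]

src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
                    [center[0] - size / 2, center[1] + size / 2],
                    [center[0] + size / 2, center[1] - size / 2]])
dst_pts = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
tform = estimate_transform('similarity', src_pts, dst_pts)

cropped_image = warp(image, tform.inverse, output_shape=(image_size, image_size))
cropped_kpt = (tform.params @ np.hstack([kpt, np.ones((kpt.shape[0], 1))]).T).T
cropped_kpt[:, :2] = cropped_kpt[:, :2] / image_size * 2 - 1   # normalize landmarks to [-1, 1]
print(cropped_image.shape, cropped_kpt.shape)                  # (224, 224, 3) (68, 3)
```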
/decalib/datasets/vox.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 | import scipy
7 | from skimage.io import imread, imsave
8 | from skimage.transform import estimate_transform, warp, resize, rescale
9 | from glob import glob
10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
11 |
12 | class VoxelDataset(Dataset):
13 | def __init__(self, K, image_size, scale, trans_scale = 0, dataname='vox2', n_train=100000, isTemporal=False, isEval=False, isSingle=False):
14 | self.K = K
15 | self.image_size = image_size
16 | if dataname == 'vox1':
17 | self.kpt_suffix = '.txt'
18 | self.imagefolder = '/ps/project/face2d3d/VoxCeleb/vox1/dev/images_cropped'
19 | self.kptfolder = '/ps/scratch/yfeng/Data/VoxCeleb/vox1/landmark_2d'
20 |
21 | self.face_dict = {}
22 | for person_id in sorted(os.listdir(self.kptfolder)):
23 | for video_id in os.listdir(os.path.join(self.kptfolder, person_id)):
24 | for face_id in os.listdir(os.path.join(self.kptfolder, person_id, video_id)):
25 | if 'txt' in face_id:
26 | continue
27 | key = person_id + '/' + video_id + '/' + face_id
28 | # if key not in self.face_dict.keys():
29 | # self.face_dict[key] = []
30 | name_list = os.listdir(os.path.join(self.kptfolder, person_id, video_id, face_id))
31 | name_list = [name.split('.')[0] for name in name_list]
32 | if len(name_list)0.5] = 1.
162 | else:
163 | mask = np.ones((h, w))
164 | return mask
165 |
--------------------------------------------------------------------------------
/decalib/models/FLAME.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | import pickle
20 | import torch.nn.functional as F
21 |
22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
23 |
24 | def to_tensor(array, dtype=torch.float32):
25 | if 'torch.tensor' not in str(type(array)):
26 | return torch.tensor(array, dtype=dtype)
27 | def to_np(array, dtype=np.float32):
28 | if 'scipy.sparse' in str(type(array)):
29 | array = array.todense()
30 | return np.array(array, dtype=dtype)
31 |
32 | class Struct(object):
33 | def __init__(self, **kwargs):
34 | for key, val in kwargs.items():
35 | setattr(self, key, val)
36 |
37 | class FLAME(nn.Module):
38 | """
39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
40 | Given flame parameters this class generates a differentiable FLAME function
41 | which outputs a mesh and 2D/3D facial landmarks
42 | """
43 | def __init__(self, config):
44 | super(FLAME, self).__init__()
45 | print("creating the FLAME Decoder")
46 | with open(config.flame_model_path, 'rb') as f:
47 | ss = pickle.load(f, encoding='latin1')
48 | flame_model = Struct(**ss)
49 |
50 | self.dtype = torch.float32
51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
52 | # The vertices of the template model
53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
54 | # The shape components and expression
55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2)
57 | self.register_buffer('shapedirs', shapedirs)
58 | # The pose components
59 | num_pose_basis = flame_model.posedirs.shape[-1]
60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
62 | #
63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1
65 | self.register_buffer('parents', parents)
66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
67 |
68 | # Fixing Eyeball and neck rotation
69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose,
71 | requires_grad=False))
72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose,
74 | requires_grad=False))
75 |
76 | # Static and Dynamic Landmark embeddings for FLAME
77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
78 | lmk_embeddings = lmk_embeddings[()]
79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
83 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
85 |
86 | neck_kin_chain = []; NECK_IDX=1
87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
88 | while curr_idx != -1:
89 | neck_kin_chain.append(curr_idx)
90 | curr_idx = self.parents[curr_idx]
91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
92 |
93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
94 | dynamic_lmk_b_coords,
95 | neck_kin_chain, dtype=torch.float32):
96 | """
97 | Selects the face contour depending on the relative position of the head
98 | Input:
99 | vertices: N X num_of_vertices X 3
100 | pose: N X full pose
101 | dynamic_lmk_faces_idx: The list of contour face indexes
102 | dynamic_lmk_b_coords: The list of contour barycentric weights
103 | neck_kin_chain: The tree to consider for the relative rotation
104 | dtype: Data type
105 | return:
106 | The contour face indexes and the corresponding barycentric weights
107 | """
108 |
109 | batch_size = pose.shape[0]
110 |
111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
112 | neck_kin_chain)
113 | rot_mats = batch_rodrigues(
114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
115 |
116 | rel_rot_mat = torch.eye(3, device=pose.device,
117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
118 | for idx in range(len(neck_kin_chain)):
119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
120 |
121 | y_rot_angle = torch.round(
122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
123 | max=39)).to(dtype=torch.long)
124 |
125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
128 | y_rot_angle = (neg_mask * neg_vals +
129 | (1 - neg_mask) * y_rot_angle)
130 |
131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
132 | 0, y_rot_angle)
133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
134 | 0, y_rot_angle)
135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
136 |
137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
138 | """
139 | Calculates landmarks by barycentric interpolation
140 | Input:
141 | vertices: torch.tensor NxVx3, dtype = torch.float32
142 | The tensor of input vertices
143 | faces: torch.tensor (N*F)x3, dtype = torch.long
144 | The faces of the mesh
145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long
146 | The tensor with the indices of the faces used to calculate the
147 | landmarks.
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
149 | The tensor of barycentric coordinates that are used to interpolate
150 | the landmarks
151 |
152 | Returns:
153 | landmarks: torch.tensor NxLx3, dtype = torch.float32
154 | The coordinates of the landmarks for each mesh in the batch
155 | """
156 | # Extract the indices of the vertices for each face
157 | # NxLx3
158 | batch_size, num_verts = vertices.shape[:2]
159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
161 |
162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
163 | device=vertices.device) * num_verts
164 |
165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces]
166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
167 | return landmarks
168 |
169 | def seletec_3d68(self, vertices):
170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
173 | return landmarks3d
174 |
175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None):
176 | """
177 | Input:
178 | shape_params: N X number of shape parameters
179 | expression_params: N X number of expression parameters
180 | pose_params: N X number of pose parameters (6)
181 | return:
182 | vertices: N X V X 3
183 | landmarks: N X number of landmarks X 3
184 | """
185 | batch_size = shape_params.shape[0]
186 | if pose_params is None:
187 | pose_params = self.eye_pose.expand(batch_size, -1)
188 | if eye_pose_params is None:
189 | eye_pose_params = self.eye_pose.expand(batch_size, -1)
190 | betas = torch.cat([shape_params, expression_params], dim=1)
191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1)
192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
193 |
194 | vertices, _ = lbs(betas, full_pose, template_vertices,
195 | self.shapedirs, self.posedirs,
196 | self.J_regressor, self.parents,
197 | self.lbs_weights, dtype=self.dtype)
198 |
199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
201 |
202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
203 | full_pose, self.dynamic_lmk_faces_idx,
204 | self.dynamic_lmk_bary_coords,
205 | self.neck_kin_chain, dtype=self.dtype)
206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
208 |
209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
210 | lmk_faces_idx,
211 | lmk_bary_coords)
212 | bz = vertices.shape[0]
213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
214 | self.full_lmk_faces_idx.repeat(bz, 1),
215 | self.full_lmk_bary_coords.repeat(bz, 1, 1))
216 | return vertices, landmarks2d, landmarks3d
217 |
218 | class FLAMETex(nn.Module):
219 | """
220 | FLAME texture:
221 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
222 | FLAME texture converted from BFM:
223 | https://github.com/TimoBolkart/BFM_to_FLAME
224 | """
225 | def __init__(self, config):
226 | super(FLAMETex, self).__init__()
227 | if config.tex_type == 'BFM':
228 | mu_key = 'MU'
229 | pc_key = 'PC'
230 | n_pc = 199
231 | tex_path = config.tex_path
232 | tex_space = np.load(tex_path)
233 | texture_mean = tex_space[mu_key].reshape(1, -1)
234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)
235 |
236 | elif config.tex_type == 'FLAME':
237 | mu_key = 'mean'
238 | pc_key = 'tex_dir'
239 | n_pc = 200
240 | tex_path = config.tex_path # config.flame_tex_path
241 | tex_space = np.load(tex_path)
242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255.
243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255.
244 | else:
245 | print('texture type ', config.tex_type, 'does not exist!')
246 | raise NotImplementedError
247 |
248 | n_tex = config.n_tex
249 | num_components = texture_basis.shape[1]
250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...]
251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...]
252 | self.register_buffer('texture_mean', texture_mean)
253 | self.register_buffer('texture_basis', texture_basis)
254 |
255 | def forward(self, texcode):
256 | '''
257 | texcode: [batchsize, n_tex]
258 | texture: [bz, 3, 256, 256], range: 0-1
259 | '''
260 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1)
261 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2)
262 | texture = F.interpolate(texture, [256, 256])
263 | texture = texture[:,[2,1,0], :,:]
264 | return texture
265 |
--------------------------------------------------------------------------------
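A hypothetical usage sketch for the FLAME decoder above, assuming the FLAME assets referenced in `decalib/utils/config.py` (generic_model.pkl, landmark_embedding.npy) have been downloaded; zero codes produce the template head.

```python
import torch
from decalib.models.FLAME import FLAME
from decalib.utils.config import cfg

flame = FLAME(cfg.model).eval()

batch = 2
shape = torch.zeros(batch, cfg.model.n_shape)   # identity coefficients
exp = torch.zeros(batch, cfg.model.n_exp)       # expression coefficients
pose = torch.zeros(batch, cfg.model.n_pose)     # global rotation (3) + jaw (3), axis-angle

with torch.no_grad():
    verts, lmk2d, lmk3d = flame(shape_params=shape, expression_params=exp, pose_params=pose)
print(verts.shape, lmk2d.shape, lmk3d.shape)    # e.g. (2, 5023, 3), (2, 68, 3), (2, 68, 3)
```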
/decalib/models/__pycache__/FLAME.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/FLAME.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/models/__pycache__/decoders.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/decoders.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/models/__pycache__/encoders.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/encoders.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/models/__pycache__/lbs.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/lbs.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/models/__pycache__/resnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/models/__pycache__/resnet.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/models/decoders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | class Generator(nn.Module):
20 | def __init__(self, latent_dim=100, out_channels=1, out_scale=0.01, sample_mode = 'bilinear'):
21 | super(Generator, self).__init__()
22 | self.out_scale = out_scale
23 |
24 | self.init_size = 32 // 4 # Initial size before upsampling
25 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
26 | self.conv_blocks = nn.Sequential(
27 | nn.BatchNorm2d(128),
28 | nn.Upsample(scale_factor=2, mode=sample_mode), #16
29 | nn.Conv2d(128, 128, 3, stride=1, padding=1),
30 | nn.BatchNorm2d(128, 0.8),
31 | nn.LeakyReLU(0.2, inplace=True),
32 | nn.Upsample(scale_factor=2, mode=sample_mode), #32
33 | nn.Conv2d(128, 64, 3, stride=1, padding=1),
34 | nn.BatchNorm2d(64, 0.8),
35 | nn.LeakyReLU(0.2, inplace=True),
36 | nn.Upsample(scale_factor=2, mode=sample_mode), #64
37 | nn.Conv2d(64, 64, 3, stride=1, padding=1),
38 | nn.BatchNorm2d(64, 0.8),
39 | nn.LeakyReLU(0.2, inplace=True),
40 | nn.Upsample(scale_factor=2, mode=sample_mode), #128
41 | nn.Conv2d(64, 32, 3, stride=1, padding=1),
42 | nn.BatchNorm2d(32, 0.8),
43 | nn.LeakyReLU(0.2, inplace=True),
44 | nn.Upsample(scale_factor=2, mode=sample_mode), #256
45 | nn.Conv2d(32, 16, 3, stride=1, padding=1),
46 | nn.BatchNorm2d(16, 0.8),
47 | nn.LeakyReLU(0.2, inplace=True),
48 | nn.Conv2d(16, out_channels, 3, stride=1, padding=1),
49 | nn.Tanh(),
50 | )
51 |
52 | def forward(self, noise):
53 | out = self.l1(noise)
54 | out = out.view(out.shape[0], 128, self.init_size, self.init_size)
55 | img = self.conv_blocks(out)
56 | return img*self.out_scale
--------------------------------------------------------------------------------
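The `Generator` above maps a latent code to an 8x8 feature map and upsamples it five times to a 256x256 single-channel map scaled by `out_scale` (DECA's detail branch uses this to predict a UV displacement map). A shape-only sketch with randomly initialized weights; the latent size here is chosen only for illustration.

```python
import torch
from decalib.models.decoders import Generator

decoder = Generator(latent_dim=128, out_channels=1, out_scale=0.01)
z = torch.randn(4, 128)        # per-image detail latent codes
disp = decoder(z)
print(disp.shape)              # torch.Size([4, 1, 256, 256])
```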
/decalib/models/encoders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import numpy as np
17 | import torch.nn as nn
18 | import torch
19 | import torch.nn.functional as F
20 | from . import resnet
21 |
22 | class ResnetEncoder(nn.Module):
23 | def __init__(self, outsize, last_op=None):
24 | super(ResnetEncoder, self).__init__()
25 | feature_size = 2048
26 | self.encoder = resnet.load_ResNet50Model() #out: 2048
27 | ### regressor
28 | self.layers = nn.Sequential(
29 | nn.Linear(feature_size, 1024),
30 | nn.ReLU(),
31 | nn.Linear(1024, outsize)
32 | )
33 | self.last_op = last_op
34 |
35 | def forward(self, inputs):
36 | features = self.encoder(inputs)
37 | parameters = self.layers(features)
38 | if self.last_op:
39 | parameters = self.last_op(parameters)
40 | return parameters
41 |
--------------------------------------------------------------------------------
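`ResnetEncoder` regresses a flat parameter vector from a 224x224 crop. A shape-only sketch with random weights; `outsize=236` is assumed here from the coarse-code sizes listed in `decalib/utils/config.py` (100 shape + 50 tex + 50 exp + 6 pose + 3 cam + 27 light).

```python
import torch
from decalib.models.encoders import ResnetEncoder

encoder = ResnetEncoder(outsize=236).eval()
images = torch.rand(2, 3, 224, 224)      # normalized face crops
with torch.no_grad():
    codes = encoder(images)
print(codes.shape)                       # torch.Size([2, 236])
```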
/decalib/models/frnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | # from pro_gan_pytorch.PRO_GAN import ProGAN, Generator, Discriminator
5 | import torch.nn.functional as F
6 | import cv2
7 | from torch.autograd import Variable
8 | import math
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = conv3x3(inplanes, planes, stride)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = conv3x3(planes, planes)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.downsample = downsample
26 | self.stride = stride
27 |
28 | def forward(self, x):
29 | residual = x
30 |
31 | out = self.conv1(x)
32 | out = self.bn1(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv2(out)
36 | out = self.bn2(out)
37 |
38 | if self.downsample is not None:
39 | residual = self.downsample(x)
40 |
41 | out += residual
42 | out = self.relu(out)
43 |
44 | return out
45 |
46 |
47 | class Bottleneck(nn.Module):
48 | expansion = 4
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(Bottleneck, self).__init__()
52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
55 | self.bn2 = nn.BatchNorm2d(planes)
56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 | self.bn3 = nn.BatchNorm2d(planes * 4)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, block, layers, num_classes=1000, include_top=True):
88 | self.inplanes = 64
89 | super(ResNet, self).__init__()
90 | self.include_top = include_top
91 |
92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
96 |
97 | self.layer1 = self._make_layer(block, 64, layers[0])
98 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
99 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
100 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
101 | self.avgpool = nn.AvgPool2d(7, stride=1)
102 | self.fc = nn.Linear(512 * block.expansion, num_classes)
103 |
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107 | m.weight.data.normal_(0, math.sqrt(2. / n))
108 | elif isinstance(m, nn.BatchNorm2d):
109 | m.weight.data.fill_(1)
110 | m.bias.data.zero_()
111 |
112 | def _make_layer(self, block, planes, blocks, stride=1):
113 | downsample = None
114 | if stride != 1 or self.inplanes != planes * block.expansion:
115 | downsample = nn.Sequential(
116 | nn.Conv2d(self.inplanes, planes * block.expansion,
117 | kernel_size=1, stride=stride, bias=False),
118 | nn.BatchNorm2d(planes * block.expansion),
119 | )
120 |
121 | layers = []
122 | layers.append(block(self.inplanes, planes, stride, downsample))
123 | self.inplanes = planes * block.expansion
124 | for i in range(1, blocks):
125 | layers.append(block(self.inplanes, planes))
126 |
127 | return nn.Sequential(*layers)
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 | x = self.bn1(x)
132 | x = self.relu(x)
133 | x = self.maxpool(x)
134 |
135 | x = self.layer1(x)
136 | x = self.layer2(x)
137 | x = self.layer3(x)
138 | x = self.layer4(x)
139 |
140 | x = self.avgpool(x)
141 |
142 | if not self.include_top:
143 | return x
144 |
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 | return x
148 |
149 | def resnet50(**kwargs):
150 | """Constructs a ResNet-50 model.
151 | """
152 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
153 | return model
154 |
155 | import pickle
156 | def load_state_dict(model, fname):
157 | """
158 | Set parameters converted from the Caffe models provided by the authors of VGGFace2.
159 | See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
160 | Arguments:
161 | model: model
162 | fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
163 | """
164 | with open(fname, 'rb') as f:
165 | weights = pickle.load(f, encoding='latin1')
166 |
167 | own_state = model.state_dict()
168 | for name, param in weights.items():
169 | if name in own_state:
170 | try:
171 | own_state[name].copy_(torch.from_numpy(param))
172 | except Exception:
173 | raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
174 | 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
175 | else:
176 | raise KeyError('unexpected key "{}" in state_dict'.format(name))
177 |
178 |
--------------------------------------------------------------------------------
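A hypothetical loading sketch for the VGGFace2 recognition backbone above. It assumes the Caffe-converted `resnet50_ft_weight.pkl` pointed to by `cfg.model.fr_model_path` is present and that its classifier head covers the 8631 VGGFace2 training identities; with `include_top=False` the network returns pooled identity features instead of logits.

```python
import torch
from decalib.models.frnet import resnet50, load_state_dict
from decalib.utils.config import cfg

fr_model = resnet50(num_classes=8631, include_top=False).eval()
load_state_dict(fr_model, cfg.model.fr_model_path)   # Caffe-converted VGGFace2 weights
with torch.no_grad():
    feats = fr_model(torch.rand(1, 3, 224, 224))
print(feats.shape)                                   # pooled identity features, (1, 2048, 1, 1)
```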
/decalib/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Soubhik Sanyal
3 | Copyright (c) 2019, Soubhik Sanyal
4 | All rights reserved.
5 | Loads different resnet models
6 | """
7 | '''
8 | file: Resnet.py
9 | date: 2018_05_02
10 | author: zhangxiong(1025679612@qq.com)
11 | mark: copied from pytorch source code
12 | '''
13 |
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch
17 | from torch.nn.parameter import Parameter
18 | import torch.optim as optim
19 | import numpy as np
20 | import math
21 | import torchvision
22 |
23 | class ResNet(nn.Module):
24 | def __init__(self, block, layers, num_classes=1000):
25 | self.inplanes = 64
26 | super(ResNet, self).__init__()
27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
28 | bias=False)
29 | self.bn1 = nn.BatchNorm2d(64)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
32 | self.layer1 = self._make_layer(block, 64, layers[0])
33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
36 | self.avgpool = nn.AvgPool2d(7, stride=1)
37 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
38 |
39 | for m in self.modules():
40 | if isinstance(m, nn.Conv2d):
41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
42 | m.weight.data.normal_(0, math.sqrt(2. / n))
43 | elif isinstance(m, nn.BatchNorm2d):
44 | m.weight.data.fill_(1)
45 | m.bias.data.zero_()
46 |
47 | def _make_layer(self, block, planes, blocks, stride=1):
48 | downsample = None
49 | if stride != 1 or self.inplanes != planes * block.expansion:
50 | downsample = nn.Sequential(
51 | nn.Conv2d(self.inplanes, planes * block.expansion,
52 | kernel_size=1, stride=stride, bias=False),
53 | nn.BatchNorm2d(planes * block.expansion),
54 | )
55 |
56 | layers = []
57 | layers.append(block(self.inplanes, planes, stride, downsample))
58 | self.inplanes = planes * block.expansion
59 | for i in range(1, blocks):
60 | layers.append(block(self.inplanes, planes))
61 |
62 | return nn.Sequential(*layers)
63 |
64 | def forward(self, x):
65 | x = self.conv1(x)
66 | x = self.bn1(x)
67 | x = self.relu(x)
68 | x = self.maxpool(x)
69 |
70 | x = self.layer1(x)
71 | x = self.layer2(x)
72 | x = self.layer3(x)
73 | x1 = self.layer4(x)
74 |
75 | x2 = self.avgpool(x1)
76 | x2 = x2.view(x2.size(0), -1)
77 | # x = self.fc(x)
78 | ## x2: [bz, 2048] for shape
79 | ## x1: [bz, 2048, 7, 7] for texture
80 | return x2
81 |
82 | class Bottleneck(nn.Module):
83 | expansion = 4
84 |
85 | def __init__(self, inplanes, planes, stride=1, downsample=None):
86 | super(Bottleneck, self).__init__()
87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
88 | self.bn1 = nn.BatchNorm2d(planes)
89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
90 | padding=1, bias=False)
91 | self.bn2 = nn.BatchNorm2d(planes)
92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
93 | self.bn3 = nn.BatchNorm2d(planes * 4)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.downsample = downsample
96 | self.stride = stride
97 |
98 | def forward(self, x):
99 | residual = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | out = self.bn3(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out += residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 | def conv3x3(in_planes, out_planes, stride=1):
121 | """3x3 convolution with padding"""
122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
123 | padding=1, bias=False)
124 |
125 | class BasicBlock(nn.Module):
126 | expansion = 1
127 |
128 | def __init__(self, inplanes, planes, stride=1, downsample=None):
129 | super(BasicBlock, self).__init__()
130 | self.conv1 = conv3x3(inplanes, planes, stride)
131 | self.bn1 = nn.BatchNorm2d(planes)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.conv2 = conv3x3(planes, planes)
134 | self.bn2 = nn.BatchNorm2d(planes)
135 | self.downsample = downsample
136 | self.stride = stride
137 |
138 | def forward(self, x):
139 | residual = x
140 |
141 | out = self.conv1(x)
142 | out = self.bn1(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv2(out)
146 | out = self.bn2(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 | def copy_parameter_from_resnet(model, resnet_dict):
157 | cur_state_dict = model.state_dict()
158 | # import ipdb; ipdb.set_trace()
159 | for name, param in list(resnet_dict.items())[0:None]:
160 | if name not in cur_state_dict:
161 | # print(name, ' not available in reconstructed resnet')
162 | continue
163 | if isinstance(param, Parameter):
164 | param = param.data
165 | try:
166 | cur_state_dict[name].copy_(param)
167 | except:
168 | # print(name, ' is inconsistent!')
169 | continue
170 | # print('copy resnet state dict finished!')
171 | # import ipdb; ipdb.set_trace()
172 |
173 | def load_ResNet50Model():
174 | model = ResNet(Bottleneck, [3, 4, 6, 3])
175 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = False).state_dict())
176 | return model
177 |
178 | def load_ResNet101Model():
179 | model = ResNet(Bottleneck, [3, 4, 23, 3])
180 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict())
181 | return model
182 |
183 | def load_ResNet152Model():
184 | model = ResNet(Bottleneck, [3, 8, 36, 3])
185 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict())
186 | return model
187 |
188 | # model.load_state_dict(checkpoint['model_state_dict'])
189 |
190 |
191 | ######## Unet
192 |
193 | class DoubleConv(nn.Module):
194 | """(convolution => [BN] => ReLU) * 2"""
195 |
196 | def __init__(self, in_channels, out_channels):
197 | super().__init__()
198 | self.double_conv = nn.Sequential(
199 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
200 | nn.BatchNorm2d(out_channels),
201 | nn.ReLU(inplace=True),
202 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
203 | nn.BatchNorm2d(out_channels),
204 | nn.ReLU(inplace=True)
205 | )
206 |
207 | def forward(self, x):
208 | return self.double_conv(x)
209 |
210 |
211 | class Down(nn.Module):
212 | """Downscaling with maxpool then double conv"""
213 |
214 | def __init__(self, in_channels, out_channels):
215 | super().__init__()
216 | self.maxpool_conv = nn.Sequential(
217 | nn.MaxPool2d(2),
218 | DoubleConv(in_channels, out_channels)
219 | )
220 |
221 | def forward(self, x):
222 | return self.maxpool_conv(x)
223 |
224 |
225 | class Up(nn.Module):
226 | """Upscaling then double conv"""
227 |
228 | def __init__(self, in_channels, out_channels, bilinear=True):
229 | super().__init__()
230 |
231 | # if bilinear, use the normal convolutions to reduce the number of channels
232 | if bilinear:
233 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
234 | else:
235 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
236 |
237 | self.conv = DoubleConv(in_channels, out_channels)
238 |
239 | def forward(self, x1, x2):
240 | x1 = self.up(x1)
241 | # input is CHW
242 | diffY = x2.size()[2] - x1.size()[2]
243 | diffX = x2.size()[3] - x1.size()[3]
244 |
245 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
246 | diffY // 2, diffY - diffY // 2])
247 | # if you have padding issues, see
248 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
249 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
250 | x = torch.cat([x2, x1], dim=1)
251 | return self.conv(x)
252 |
253 |
254 | class OutConv(nn.Module):
255 | def __init__(self, in_channels, out_channels):
256 | super(OutConv, self).__init__()
257 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
258 |
259 | def forward(self, x):
260 | return self.conv(x)
--------------------------------------------------------------------------------
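Besides the ResNet loaders, this file defines generic U-Net blocks. A quick shape check (illustrative only, default bilinear upsampling) of how `Down` and `Up` compose with a skip connection:

```python
import torch
from decalib.models.resnet import DoubleConv, Down, Up, OutConv

inc = DoubleConv(3, 16)
down = Down(16, 32)
up = Up(48, 16)            # 32 upsampled channels + 16 skip channels concatenated -> 48 in
outc = OutConv(16, 1)

x = torch.rand(1, 3, 64, 64)
x1 = inc(x)                # (1, 16, 64, 64)
x2 = down(x1)              # (1, 32, 32, 32)
y = outc(up(x2, x1))       # skip connection restores the full resolution: (1, 1, 64, 64)
print(y.shape)
```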
/decalib/smirk/__pycache__/mediapipe_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/__pycache__/mediapipe_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/smirk/__pycache__/smirk_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/__pycache__/smirk_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/smirk/face_landmarker.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/smirk/face_landmarker.task
--------------------------------------------------------------------------------
/decalib/smirk/mediapipe_utils.py:
--------------------------------------------------------------------------------
1 | import mediapipe as mp
2 | from mediapipe.tasks import python
3 | from mediapipe.tasks.python import vision
4 | import cv2
5 | import numpy as np
6 |
7 | base_options = python.BaseOptions(model_asset_path='/home/mengting/projects/diffusionRig/decalib/smirk/face_landmarker.task')
8 | options = vision.FaceLandmarkerOptions(base_options=base_options,
9 | output_face_blendshapes=True,
10 | output_facial_transformation_matrixes=True,
11 | num_faces=1,
12 | min_face_detection_confidence=0.1,
13 | min_face_presence_confidence=0.1
14 | )
15 | detector = vision.FaceLandmarker.create_from_options(options)
16 |
17 |
18 | def run_mediapipe(image):
19 | # print(image.shape)
20 | image_numpy = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
21 |
22 | # STEP 3: Load the input image.
23 | image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_numpy)
24 |
25 |
26 | # STEP 4: Detect face landmarks from the input image.
27 | detection_result = detector.detect(image)
28 |
29 | if len(detection_result.face_landmarks) == 0:
30 | print('No face detected')
31 | return None
32 |
33 | face_landmarks = detection_result.face_landmarks[0]
34 |
35 | face_landmarks_numpy = np.zeros((478, 3))
36 |
37 | for i, landmark in enumerate(face_landmarks):
38 | face_landmarks_numpy[i] = [landmark.x*image.width, landmark.y*image.height, landmark.z]
39 |
40 | return face_landmarks_numpy
41 |
--------------------------------------------------------------------------------
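`run_mediapipe` wraps the MediaPipe FaceLandmarker task and returns the 478 landmarks in pixel coordinates. A call sketch; the image path below is a placeholder, and the hard-coded `face_landmarker.task` path at the top of the module must be pointed at a local copy of the model file before the import succeeds on another machine.

```python
import cv2
from decalib.smirk.mediapipe_utils import run_mediapipe

frame = cv2.imread('path/to/face.jpg')     # any BGR face image (placeholder path)
landmarks = run_mediapipe(frame)           # (478, 3) array, or None if no face is found
if landmarks is not None:
    print(landmarks.shape)
```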
/decalib/smirk/smirk_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import timm
5 |
6 |
7 | def create_backbone(backbone_name, pretrained=True):
8 | backbone = timm.create_model(backbone_name,
9 | pretrained=pretrained,
10 | features_only=True)
11 | feature_dim = backbone.feature_info[-1]['num_chs']
12 | return backbone, feature_dim
13 |
14 | class PoseEncoder(nn.Module):
15 | def __init__(self) -> None:
16 | super().__init__()
17 |
18 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100')
19 |
20 | self.pose_cam_layers = nn.Sequential(
21 | nn.Linear(feature_dim, 6)
22 | )
23 |
24 | self.init_weights()
25 |
26 | def init_weights(self):
27 | self.pose_cam_layers[-1].weight.data *= 0.001
28 | self.pose_cam_layers[-1].bias.data *= 0.001
29 |
30 | self.pose_cam_layers[-1].weight.data[3] = 0
31 | self.pose_cam_layers[-1].bias.data[3] = 7
32 |
33 |
34 | def forward(self, img):
35 | features = self.encoder(img)[-1]
36 |
37 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
38 |
39 | outputs = {}
40 |
41 | pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1)
42 | outputs['pose_params'] = pose_cam[...,:3]
43 | outputs['cam'] = pose_cam[...,3:]
44 |
45 | return outputs
46 |
47 |
48 | class ShapeEncoder(nn.Module):
49 | def __init__(self, n_shape=300) -> None:
50 | super().__init__()
51 |
52 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
53 |
54 | self.shape_layers = nn.Sequential(
55 | nn.Linear(feature_dim, n_shape)
56 | )
57 |
58 | self.init_weights()
59 |
60 |
61 | def init_weights(self):
62 | self.shape_layers[-1].weight.data *= 0
63 | self.shape_layers[-1].bias.data *= 0
64 |
65 |
66 | def forward(self, img):
67 | features = self.encoder(img)[-1]
68 |
69 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
70 |
71 | parameters = self.shape_layers(features).reshape(img.size(0), -1)
72 |
73 | return {'shape_params': parameters}
74 |
75 |
76 | class ExpressionEncoder(nn.Module):
77 | def __init__(self, n_exp=50) -> None:
78 | super().__init__()
79 |
80 | self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
81 |
82 | self.expression_layers = nn.Sequential(
83 | nn.Linear(feature_dim, n_exp+2+3) # num expressions + jaw + eyelid
84 | )
85 |
86 | self.n_exp = n_exp
87 | self.init_weights()
88 |
89 |
90 | def init_weights(self):
91 | self.expression_layers[-1].weight.data *= 0.1
92 | self.expression_layers[-1].bias.data *= 0.1
93 |
94 |
95 | def forward(self, img):
96 | features = self.encoder(img)[-1]
97 |
98 | features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
99 |
100 |
101 | parameters = self.expression_layers(features).reshape(img.size(0), -1)
102 |
103 | outputs = {}
104 |
105 | outputs['expression_params'] = parameters[...,:self.n_exp]
106 | outputs['eyelid_params'] = torch.clamp(parameters[...,self.n_exp:self.n_exp+2], 0, 1)
107 | outputs['jaw_params'] = torch.cat([F.relu(parameters[...,self.n_exp+2].unsqueeze(-1)),
108 | torch.clamp(parameters[...,self.n_exp+3:self.n_exp+5], -.2, .2)], dim=-1)
109 |
110 | return outputs
111 |
112 |
113 | class SmirkEncoder(nn.Module):
114 | def __init__(self, n_exp=50, n_shape=300) -> None:
115 | super().__init__()
116 |
117 | self.pose_encoder = PoseEncoder()
118 |
119 | self.shape_encoder = ShapeEncoder(n_shape=n_shape)
120 |
121 | self.expression_encoder = ExpressionEncoder(n_exp=n_exp)
122 |
123 | def forward(self, img):
124 | pose_outputs = self.pose_encoder(img)
125 | shape_outputs = self.shape_encoder(img)
126 | expression_outputs = self.expression_encoder(img)
127 |
128 | outputs = {}
129 | outputs.update(pose_outputs)
130 | outputs.update(shape_outputs)
131 | outputs.update(expression_outputs)
132 |
133 | return outputs
134 |
--------------------------------------------------------------------------------
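`SmirkEncoder` bundles the three sub-encoders above into one dict of FLAME codes. A shape-only sketch (timm backbones with their default ImageNet initialization; the actual pipeline loads a pretrained SMIRK checkpoint on top):

```python
import torch
from decalib.smirk.smirk_encoder import SmirkEncoder

smirk = SmirkEncoder(n_exp=50, n_shape=300).eval()
crops = torch.rand(2, 3, 224, 224)         # face crops in [0, 1]
with torch.no_grad():
    out = smirk(crops)
for k, v in out.items():
    print(k, tuple(v.shape))
# pose_params (2, 3), cam (2, 3), shape_params (2, 300),
# expression_params (2, 50), eyelid_params (2, 2), jaw_params (2, 3)
```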
/decalib/smirk/utils/masking.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import cv2
7 |
8 |
9 |
10 | def load_probabilities_per_FLAME_triangle():
11 | """
12 | FLAME_masks_triangles.npy contains for each face area the indices of the triangles that belong to that area.
13 | Using that, we can assign a probability to each triangle based on the area it belongs to, and then sample for masking.
14 | """
15 | flame_masks_triangles = np.load('assets/FLAME_masks/FLAME_masks_triangles.npy', allow_pickle=True).item()
16 |
17 | area_weights = {
18 | 'neck': 0.0,
19 | 'right_eyeball': 0.0,
20 | 'right_ear': 0.0,
21 | 'lips': 0.5,
22 | 'nose': 0.5,
23 | 'left_ear': 0.0,
24 | 'eye_region': 1.0,
25 | 'forehead':1.0,
26 | 'left_eye_region': 1.0,
27 | 'right_eye_region': 1.0,
28 | 'face_clean': 1.0,
29 | 'cleaner_lips': 1.0
30 | }
31 |
32 | face_probabilities = torch.zeros(9976)
33 |
34 | for area in area_weights.keys():
35 | face_probabilities[flame_masks_triangles[area]] = area_weights[area]
36 |
37 | return face_probabilities
38 |
39 |
40 | def triangle_area(vertices):
41 | # Using the Shoelace formula to calculate the area of triangles in the xy plane
42 | # vertices is expected to be of shape (..., 3, 2) where the last dimension holds x and y coordinates.
43 | x1, y1 = vertices[..., 0, 0], vertices[..., 0, 1]
44 | x2, y2 = vertices[..., 1, 0], vertices[..., 1, 1]
45 | x3, y3 = vertices[..., 2, 0], vertices[..., 2, 1]
46 |
47 | # Shoelace formula for the area of a triangle given by coordinates (x1, y1), (x2, y2), (x3, y3)
48 | area = 0.5 * torch.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3)
49 | return area
50 |
51 |
52 |
53 | def random_barycentric(num=1):
54 | # Generate two random numbers for each set
55 | u, v = torch.rand(num), torch.rand(num)
56 |
57 | # Adjust the random numbers if they are outside the triangle
58 | outside_triangle = u + v > 1
59 | u[outside_triangle], v[outside_triangle] = 1 - u[outside_triangle], 1 - v[outside_triangle]
60 |
61 | # Calculate the barycentric coordinates
62 | alpha = 1 - (u + v)
63 | beta = u
64 | gamma = v
65 |
66 | # Combine the coordinates into a single tensor
67 | return torch.stack((alpha, beta, gamma), dim=1)
68 |
69 |
70 | def masking(img, mask, extra_points, wr=15, rendered_mask=None, extra_noise=True, random_mask=0.01):
71 | # img: B x C x H x W
72 | # mask: B x 1 x H x W
73 |
74 | B, C, H, W = img.size()
75 |
76 | # dilate face mask, drawn from convex hull of face landmarks
77 | mask = 1-F.max_pool2d(1-mask, 2 * wr + 1, stride=1, padding=wr)
78 |
79 | # optionally remove the rendered mask
80 | if rendered_mask is not None:
81 | mask = mask * (1 - rendered_mask)
82 |
83 | masked_img = img * mask
84 | # add noise to extra in-face points
85 | if extra_noise:
86 | # normal around 1 with std 0.1
87 | noise_mult = torch.randn(extra_points.shape).to(img.device) * 0.05 + 1
88 | extra_points = extra_points * noise_mult
89 |
90 | # select random_mask percentage of pixels as centers to crop out 11x11 patches
91 | if random_mask > 0:
92 | random_mask = torch.bernoulli(torch.ones((B, 1, H, W)) * random_mask).to(img.device)
93 | # dilate the mask to have 11x11 patches
94 | random_mask = 1 - F.max_pool2d(random_mask, 11, stride=1, padding=5)
95 |
96 | extra_points = extra_points * random_mask
97 |
98 | masked_img[extra_points > 0] = extra_points[extra_points > 0]
99 |
100 | masked_img = masked_img.detach()
101 | return masked_img
102 |
103 |
104 |
105 | def point2ind(npoints, H):
106 |
107 | npoints = npoints * (H // 2) + H // 2
108 | npoints = npoints.long()
109 | npoints[...,1] = torch.clamp(npoints[..., 1], 0, H-1)
110 | npoints[...,0] = torch.clamp(npoints[..., 0], 0, H-1)
111 |
112 | return npoints
113 |
114 |
115 | def transfer_pixels(img, points1, points2, rbound=None):
116 |
117 | B, C, H, W = img.size()
118 | retained_pixels = torch.zeros_like(img).to(img.device)
119 |
120 | if rbound is not None:
121 | for bi in range(B):
122 | retained_pixels[bi, :, points2[bi, :rbound[bi], 1], points2[bi, :rbound[bi], 0]] = \
123 | img[bi, :, points1[bi, :rbound[bi], 1], points1[bi, :rbound[bi], 0]]
124 | else:
125 | retained_pixels[torch.arange(B).unsqueeze(-1), :, points2[..., 1], points2[..., 0]] = \
126 | img[torch.arange(B).unsqueeze(-1), :, points1[..., 1], points1[..., 0]]
127 |
128 | return retained_pixels
129 |
130 |
131 | def mesh_based_mask_uniform_faces(flame_trans_verts, flame_faces, face_probabilities, mask_ratio=0.1, coords=None, IMAGE_SIZE=224):
132 | """
133 | This function samples points from the FLAME mesh based on the face probabilities and the mask ratio.
134 | """
135 | batch_size = flame_trans_verts.size(0)
136 | DEVICE = flame_trans_verts.device
137 |
138 | # if mask_ratio is single value, then use it as a ratio of the image size
139 | num_points_to_sample = int(mask_ratio * IMAGE_SIZE * IMAGE_SIZE)
140 |
141 | flame_faces_expanded = flame_faces.expand(batch_size, -1, -1)
142 |
143 | if coords is None:
144 | # calculate face normals
145 | transformed_normals = vertex_normals(flame_trans_verts, flame_faces_expanded)
146 | transformed_face_normals = face_vertices(transformed_normals, flame_faces_expanded)
147 | transformed_face_normals = transformed_face_normals[:,:,:,2].mean(dim=-1)
148 | face_probabilities = face_probabilities.repeat(batch_size,1).to(flame_trans_verts.device)
149 |
150 | # # where the face normals are negative, set probability to 0
151 | face_probabilities = torch.where(transformed_face_normals < 0.05, face_probabilities, torch.zeros_like(transformed_face_normals).to(DEVICE))
152 | # face_probabilities = torch.where(transformed_face_normals > 0, torch.ones_like(transformed_face_normals).to(flame_trans_verts.device), face_probabilities)
153 |
154 | # calculate xy area of faces and scale the probabilities by it
155 | fv = face_vertices(flame_trans_verts, flame_faces_expanded)
156 | xy_area = triangle_area(fv)
157 |
158 | face_probabilities = face_probabilities * xy_area
159 |
160 |
161 | sampled_faces_indices = torch.multinomial(face_probabilities, num_points_to_sample, replacement=True).to(DEVICE)
162 |
163 | barycentric_coords = random_barycentric(num=batch_size*num_points_to_sample).to(DEVICE)
164 | barycentric_coords = barycentric_coords.view(batch_size, num_points_to_sample, 3)
165 | else:
166 | sampled_faces_indices = coords['sampled_faces_indices']
167 | barycentric_coords = coords['barycentric_coords']
168 |
169 | npoints = vertices2landmarks(flame_trans_verts, flame_faces, sampled_faces_indices, barycentric_coords)
170 |
171 | npoints = .5 * (1 + npoints) * IMAGE_SIZE
172 | npoints = npoints.long()
173 | npoints[...,1] = torch.clamp(npoints[..., 1], 0, IMAGE_SIZE-1)
174 | npoints[...,0] = torch.clamp(npoints[..., 0], 0, IMAGE_SIZE-1)
175 |
176 | #mask = torch.zeros((flame_output['trans_verts'].size(0), 1, self.config.image_size, self.config.image_size)).to(flame_output['trans_verts'].device)
177 |
178 | #mask[torch.arange(batch_size).unsqueeze(-1), :, npoints[..., 1], npoints[..., 0]] = 1
179 |
180 | return npoints, {'sampled_faces_indices':sampled_faces_indices, 'barycentric_coords':barycentric_coords}
181 |
--------------------------------------------------------------------------------
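Two of the helpers above are easy to sanity-check in isolation: `random_barycentric` draws valid barycentric coordinates (non-negative, summing to one) and `triangle_area` implements the Shoelace formula in the xy plane. An illustrative check, not part of the repo:

```python
import torch
from decalib.smirk.utils.masking import random_barycentric, triangle_area

bary = random_barycentric(num=4)
print(bary.shape, bary.sum(dim=1))         # torch.Size([4, 3]) tensor([1., 1., 1., 1.])

tri = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])   # one unit right triangle, shape (1, 3, 2)
print(triangle_area(tri))                  # tensor([0.5000])
```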
/decalib/smirk/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | def load_templates():
6 | templates_path = "assets/expression_templates_famos"
7 | classes_to_load = ["lips_back", "rolling_lips", "mouth_side", "kissing", "high_smile", "mouth_up",
8 | "mouth_middle", "mouth_down", "blow_cheeks", "cheeks_in", "jaw", "lips_up"]
9 | templates = {}
10 | for subject in os.listdir(templates_path):
11 | if os.path.isdir(os.path.join(templates_path, subject)):
12 | for template in os.listdir(os.path.join(templates_path, subject)):
13 | if template.endswith(".mp4"):
14 | continue
15 | if template not in classes_to_load:
16 | continue
17 | exps = []
18 | for npy_file in os.listdir(os.path.join(templates_path, subject, template)):
19 | params = np.load(os.path.join(templates_path, subject, template, npy_file), allow_pickle=True)
20 | exp = params.item()['expression'].squeeze()
21 | exps.append(exp)
22 | templates[subject+template] = np.array(exps)
23 | print('Number of expression templates loaded: ', len(templates.keys()))
24 |
25 | return templates
26 |
27 |
28 |
29 | def tensor_to_image(image_tensor):
30 | """Converts a tensor to a numpy image."""
31 | image = image_tensor.permute(1,2,0).cpu().numpy()*255.0
32 | image = np.clip(image, 0, 255)
33 | image = image.astype(np.uint8)
34 | return image
35 |
36 | def image_to_tensor(image):
37 | """Converts a numpy image to a tensor."""
38 | image = torch.from_numpy(image).permute(2,0,1).float()/255.0
39 | return image
40 |
41 |
42 | def count_parameters(model):
43 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
44 |
45 |
46 | def freeze_module(module, module_name=None):
47 |
48 | for param in module.parameters():
49 | param.requires_grad_(False)
50 |
51 | module.eval()
52 |
53 |
54 | def unfreeze_module(module, module_name=None):
55 |
56 | for param in module.parameters():
57 | param.requires_grad_(True)
58 |
59 | module.train()
60 |
61 | import cv2
62 | from torchvision.utils import make_grid
63 |
64 |
65 | def batch_draw_keypoints(images, landmarks, color=(255, 255, 255), radius=1):
66 | if isinstance(landmarks, torch.Tensor):
67 | landmarks = landmarks.cpu().numpy()
68 | landmarks = landmarks.copy()*112 + 112
69 |
70 | if isinstance(images, torch.Tensor):
71 | images = images.cpu().numpy().transpose(0, 2, 3, 1)
72 | images = (images * 255).astype('uint8')
73 | images = np.ascontiguousarray(images[..., ::-1])
74 |
75 | plotted_images = []
76 | for image, landmark in zip(images, landmarks):
77 | for point in landmark:
78 | image = cv2.circle(image, (int(point[0]), int(point[1])), radius, color, -1)
79 | plotted_images.append(image)
80 |
81 | return plotted_images
82 |
83 | def make_grid_from_opencv_images(images, nrow=12):
84 | """ Create a grid of images from the list of cv2 images in images"""
85 | images = np.array(images)
86 | images = images[..., ::-1]
87 | images = np.array(images)
88 | images = torch.from_numpy(images).permute(0, 3, 1, 2).float()/255.
89 | grid = make_grid(images, nrow=nrow)
90 | return grid
--------------------------------------------------------------------------------
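`image_to_tensor` and `tensor_to_image` convert between HxWx3 uint8 images and 3xHxW float tensors in [0, 1]. A round-trip sketch on a synthetic image; differences stay within rounding.

```python
import numpy as np
from decalib.smirk.utils.utils import image_to_tensor, tensor_to_image

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
t = image_to_tensor(img)                   # (3, 64, 64) float32 in [0, 1]
restored = tensor_to_image(t)              # (64, 64, 3) uint8
print(t.shape, int(np.abs(restored.astype(int) - img.astype(int)).max()))   # at most 1 off from rounding
```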
/decalib/utils/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/utils/__pycache__/renderer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/renderer.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/utils/__pycache__/rotation_converter.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/rotation_converter.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/utils/__pycache__/tensor_cropper.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/tensor_cropper.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/decalib/utils/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Default config for DECA
3 | '''
4 | from yacs.config import CfgNode as CN
5 | import argparse
6 | import yaml
7 | import os
8 |
9 | cfg = CN()
10 |
11 | abs_deca_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
12 | abs_deca_dir = os.path.join(abs_deca_dir, 'decalib')
13 | cfg.deca_dir = abs_deca_dir
14 | cfg.device = 'cuda'
15 | cfg.device_id = '0'
16 |
17 | cfg.pretrained_modelpath = os.path.join(cfg.deca_dir, 'data', 'deca_model.tar')
18 | cfg.output_dir = ''
19 | cfg.rasterizer_type = 'pytorch3d'
20 | # ---------------------------------------------------------------------------- #
21 | # Options for Face model
22 | # ---------------------------------------------------------------------------- #
23 | cfg.model = CN()
24 | cfg.model.topology_path = os.path.join(cfg.deca_dir, 'data', 'head_template.obj')
25 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip
26 | cfg.model.dense_template_path = os.path.join(cfg.deca_dir, 'data', 'texture_data_256.npy')
27 | cfg.model.fixed_displacement_path = os.path.join(cfg.deca_dir, 'data', 'fixed_displacement_256.npy')
28 | cfg.model.flame_model_path = os.path.join(cfg.deca_dir, 'data', 'generic_model.pkl')
29 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.deca_dir, 'data', 'landmark_embedding.npy')
30 | cfg.model.face_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_mask.png')
31 | cfg.model.face_eye_mask_path = os.path.join(cfg.deca_dir, 'data', 'uv_face_eye_mask.png')
32 | cfg.model.mean_tex_path = os.path.join(cfg.deca_dir, 'data', 'mean_texture.jpg')
33 | cfg.model.tex_path = os.path.join(cfg.deca_dir, 'data', 'FLAME_albedo_from_BFM.npz')
34 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM
35 | cfg.model.uv_size = 256
36 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light']
37 | cfg.model.n_shape = 100
38 | cfg.model.n_tex = 50
39 | cfg.model.n_exp = 50
40 | cfg.model.n_cam = 3
41 | cfg.model.n_pose = 6
42 | cfg.model.n_light = 27
43 | cfg.model.use_tex = True
44 | cfg.model.jaw_type = 'aa' # default: axis-angle ('aa'); another option: 'euler'. Note that 'aa' is unstable early in training
45 | # face recognition model
46 | cfg.model.fr_model_path = os.path.join(cfg.deca_dir, 'data', 'resnet50_ft_weight.pkl')
47 |
48 | ## details
49 | cfg.model.n_detail = 128
50 | cfg.model.max_z = 0.01
51 |
52 | # ---------------------------------------------------------------------------- #
53 | # Options for Dataset
54 | # ---------------------------------------------------------------------------- #
55 | cfg.dataset = CN()
56 | cfg.dataset.training_data = ['vggface2', 'ethnicity']
57 | # cfg.dataset.training_data = ['ethnicity']
58 | cfg.dataset.eval_data = ['aflw2000']
59 | cfg.dataset.test_data = ['']
60 | cfg.dataset.batch_size = 2
61 | cfg.dataset.K = 4
62 | cfg.dataset.isSingle = False
63 | cfg.dataset.num_workers = 2
64 | cfg.dataset.image_size = 224
65 | cfg.dataset.scale_min = 1.4
66 | cfg.dataset.scale_max = 1.8
67 | cfg.dataset.trans_scale = 0.
68 |
69 | # ---------------------------------------------------------------------------- #
70 | # Options for training
71 | # ---------------------------------------------------------------------------- #
72 | cfg.train = CN()
73 | cfg.train.train_detail = False
74 | cfg.train.max_epochs = 500
75 | cfg.train.max_steps = 1000000
76 | cfg.train.lr = 1e-4
77 | cfg.train.log_dir = 'logs'
78 | cfg.train.log_steps = 10
79 | cfg.train.vis_dir = 'train_images'
80 | cfg.train.vis_steps = 200
81 | cfg.train.write_summary = True
82 | cfg.train.checkpoint_steps = 500
83 | cfg.train.val_steps = 500
84 | cfg.train.val_vis_dir = 'val_images'
85 | cfg.train.eval_steps = 5000
86 | cfg.train.resume = True
87 |
88 | # ---------------------------------------------------------------------------- #
89 | # Options for Losses
90 | # ---------------------------------------------------------------------------- #
91 | cfg.loss = CN()
92 | cfg.loss.lmk = 1.0
93 | cfg.loss.useWlmk = True
94 | cfg.loss.eyed = 1.0
95 | cfg.loss.lipd = 0.5
96 | cfg.loss.photo = 2.0
97 | cfg.loss.useSeg = True
98 | cfg.loss.id = 0.2
99 | cfg.loss.id_shape_only = True
100 | cfg.loss.reg_shape = 1e-04
101 | cfg.loss.reg_exp = 1e-04
102 | cfg.loss.reg_tex = 1e-04
103 | cfg.loss.reg_light = 1.
104 | cfg.loss.reg_jaw_pose = 0. #1.
105 | cfg.loss.use_gender_prior = False
106 | cfg.loss.shape_consistency = True
107 | # loss for detail
108 | cfg.loss.detail_consistency = True
109 | cfg.loss.useConstraint = True
110 | cfg.loss.mrf = 5e-2
111 | cfg.loss.photo_D = 2.
112 | cfg.loss.reg_sym = 0.005
113 | cfg.loss.reg_z = 0.005
114 | cfg.loss.reg_diff = 0.005
115 |
116 |
117 | def get_cfg_defaults():
118 | """Get a yacs CfgNode object with default values for my_project."""
119 | # Return a clone so that the defaults will not be altered
120 | # This is for the "local variable" use pattern
121 | return cfg.clone()
122 |
123 | def update_cfg(cfg, cfg_file):
124 | cfg.merge_from_file(cfg_file)
125 | return cfg.clone()
126 |
127 | def parse_args():
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--cfg', type=str, help='cfg file path')
130 | parser.add_argument('--mode', type=str, default = 'train', help='deca mode')
131 |
132 | args = parser.parse_args()
133 | print(args, end='\n\n')
134 |
135 | cfg = get_cfg_defaults()
136 | cfg.cfg_file = None
137 | cfg.mode = args.mode
138 | # import ipdb; ipdb.set_trace()
139 | if args.cfg is not None:
140 | cfg_file = args.cfg
141 | cfg = update_cfg(cfg, args.cfg)
142 | cfg.cfg_file = cfg_file
143 |
144 | return cfg
145 |
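146 |
147 | # Usage sketch (the YAML path below is hypothetical): start from the defaults and
148 | # optionally merge overrides from a config file.
149 | if __name__ == '__main__':
150 |     cfg = get_cfg_defaults()                   # clone of the defaults defined above
151 |     # cfg = update_cfg(cfg, 'my_exp.yml')      # merge overrides from a YAML file
152 |     print(cfg.model.n_exp, cfg.dataset.image_size)   # 50 224 unless overridden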
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Install
2 | from standard_rasterize_cuda import standard_rasterize
3 | # from .rasterizer.standard_rasterize_cuda import standard_rasterize
4 |
5 | in this folder, run
6 | ```python setup.py build_ext -i ```
7 |
8 | then remember to set --rasterizer_type=standard when running demos :)
9 |
10 | ## Alg
11 | https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation
12 |
13 | ## Speed Comparison
14 | runtime for rasterization only
15 | In PIXIE, number of faces in SMPLX: 20908
16 |
17 | for image size = 1024
18 | pytorch3d: 0.031s
19 | standard: 0.01s
20 |
21 | for image size = 224
22 | pytorch3d: 0.0035s
23 | standard: 0.0014s
24 |
25 | Why is the standard rasterizer faster than pytorch3d?
26 | Ref: https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
27 | pytorch3d: for each pixel in image space (pixels are parallel in CUDA), loop through all faces, check whether the pixel lies inside the face's projected bounding box, sort the faces by z, and record the face ids of the closest K faces.
28 | standard rasterization: for each face in the mesh (faces are parallel in CUDA), loop only through the pixels in that face's projected bounding box (normally a very small number), compare z, and record the face id for each covered pixel.
29 |
30 |
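31 | ## Usage sketch
32 | A minimal call sketch (an illustration, not from the original instructions); buffer shapes and
33 | dtypes follow the kernel signatures in standard_rasterize_cuda.cpp / standard_rasterize_cuda_kernel.cu,
34 | and all buffers are written in place:
35 | ```python
36 | import torch
37 | from standard_rasterize_cuda import standard_rasterize
38 |
39 | bz, ntri, h, w = 1, 20908, 224, 224
40 | face_vertices = torch.rand(bz, ntri, 3, 3).cuda()                        # projected (x, y, z) per triangle vertex
41 | depth_buffer = torch.full((bz, h, w), 1e6).cuda()                        # initialized far; the kernel takes an atomic min over z
42 | triangle_buffer = torch.full((bz, h, w), -1, dtype=torch.int32).cuda()   # face id per pixel
43 | baryw_buffer = torch.zeros(bz, h, w, 3).cuda()                           # barycentric weights per pixel
44 | standard_rasterize(face_vertices, depth_buffer, triangle_buffer, baryw_buffer, h, w)
45 | ```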
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/decalib/utils/rasterizer/__init__.py
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/setup.py:
--------------------------------------------------------------------------------
1 | # To install, run
2 | # python setup.py build_ext -i
3 | # Ref: https://github.com/pytorch/pytorch/blob/11a40410e755b1fe74efe9eaa635e7ba5712846b/test/cpp_extensions/setup.py#L62
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 | import os
8 |
9 | # USE_NINJA = os.getenv('USE_NINJA') == '1'
10 | os.environ["CC"] = "gcc-7"   # hard-coded toolchain; adjust to the compilers available locally
11 | os.environ["CXX"] = "gcc-7"  # note: a C++ compiler (e.g. g++-7) is normally expected here
12 |
13 | USE_NINJA = os.getenv('USE_NINJA') == '1'
14 |
15 | setup(
16 | name='standard_rasterize_cuda',
17 | ext_modules=[
18 | CUDAExtension('standard_rasterize_cuda', [
19 | 'standard_rasterize_cuda.cpp',
20 | 'standard_rasterize_cuda_kernel.cu',
21 | ])
22 | ],
23 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=USE_NINJA)}
24 | )
25 |
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/standard_rasterize_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | #include <iostream>
4 |
5 | std::vector<at::Tensor> forward_rasterize_cuda(
6 | at::Tensor face_vertices,
7 | at::Tensor depth_buffer,
8 | at::Tensor triangle_buffer,
9 | at::Tensor baryw_buffer,
10 | int h,
11 | int w);
12 |
13 | std::vector<at::Tensor> standard_rasterize(
14 | at::Tensor face_vertices,
15 | at::Tensor depth_buffer,
16 | at::Tensor triangle_buffer,
17 | at::Tensor baryw_buffer,
18 | int height, int width
19 | ) {
20 | return forward_rasterize_cuda(face_vertices, depth_buffer, triangle_buffer, baryw_buffer, height, width);
21 | }
22 |
23 | std::vector<at::Tensor> forward_rasterize_colors_cuda(
24 | at::Tensor face_vertices,
25 | at::Tensor face_colors,
26 | at::Tensor depth_buffer,
27 | at::Tensor triangle_buffer,
28 | at::Tensor images,
29 | int h,
30 | int w);
31 |
32 | std::vector<at::Tensor> standard_rasterize_colors(
33 | at::Tensor face_vertices,
34 | at::Tensor face_colors,
35 | at::Tensor depth_buffer,
36 | at::Tensor triangle_buffer,
37 | at::Tensor images,
38 | int height, int width
39 | ) {
40 | return forward_rasterize_colors_cuda(face_vertices, face_colors, depth_buffer, triangle_buffer, images, height, width);
41 | }
42 |
43 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
44 | m.def("standard_rasterize", &standard_rasterize, "RASTERIZE (CUDA)");
45 | m.def("standard_rasterize_colors", &standard_rasterize_colors, "RASTERIZE COLORS (CUDA)");
46 | }
47 |
48 | // TODO: backward
--------------------------------------------------------------------------------
/decalib/utils/rasterizer/standard_rasterize_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | // Ref: https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/cuda/rasterize_cuda_kernel.cu
2 | // https://github.com/YadiraF/face3d/blob/master/face3d/mesh/cython/mesh_core.cpp
3 |
4 | #include <ATen/ATen.h>
5 |
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 |
9 | namespace{
10 | __device__ __forceinline__ float atomicMin(float* address, float val)
11 | {
12 | int* address_as_i = (int*) address;
13 | int old = *address_as_i, assumed;
14 | do {
15 | assumed = old;
16 | old = atomicCAS(address_as_i, assumed,
17 | __float_as_int(fminf(val, __int_as_float(assumed))));
18 | } while (assumed != old);
19 | return __int_as_float(old);
20 | }
21 | __device__ __forceinline__ double atomicMin(double* address, double val)
22 | {
23 | unsigned long long int* address_as_i = (unsigned long long int*) address;
24 | unsigned long long int old = *address_as_i, assumed;
25 | do {
26 | assumed = old;
27 | old = atomicCAS(address_as_i, assumed,
28 | __double_as_longlong(fminf(val, __longlong_as_double(assumed))));
29 | } while (assumed != old);
30 | return __longlong_as_double(old);
31 | }
32 |
33 | template <typename scalar_t>
34 | __device__ __forceinline__ bool check_face_frontside(const scalar_t *face) {
35 | return (face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]);
36 | }
37 |
38 |
39 | template <typename scalar_t> struct point
40 | {
41 | public:
42 | scalar_t x;
43 | scalar_t y;
44 |
45 | __host__ __device__ scalar_t dot(point p)
46 | {
47 | return this->x * p.x + this->y * p.y;
48 | };
49 |
50 | __host__ __device__ point operator-(point& p)
51 | {
52 | point np;
53 | np.x = this->x - p.x;
54 | np.y = this->y - p.y;
55 | return np;
56 | };
57 |
58 | __host__ __device__ point operator+(point& p)
59 | {
60 | point np;
61 | np.x = this->x + p.x;
62 | np.y = this->y + p.y;
63 | return np;
64 | };
65 |
66 | __host__ __device__ point operator*(scalar_t s)
67 | {
68 | point np;
69 | np.x = s * this->x;
70 | np.y = s * this->y;
71 | return np;
72 | };
73 | };
74 |
75 | template <typename scalar_t>
76 | __device__ __forceinline__ bool check_pixel_inside(const scalar_t *w) {
77 | return w[0] <= 1 && w[0] >= 0 && w[1] <= 1 && w[1] >= 0 && w[2] <= 1 && w[2] >= 0;
78 | }
79 |
80 | template <typename scalar_t>
81 | __device__ __forceinline__ void barycentric_weight(scalar_t *w, point<scalar_t> p, point<scalar_t> p0, point<scalar_t> p1, point<scalar_t> p2) {
82 |
83 | // vectors
84 |     point<scalar_t> v0, v1, v2;
85 | scalar_t s = p.dot(p);
86 | v0 = p2 - p0;
87 | v1 = p1 - p0;
88 | v2 = p - p0;
89 |
90 | // dot products
91 | scalar_t dot00 = v0.dot(v0); //v0.x * v0.x + v0.y * v0.y //np.dot(v0.T, v0)
92 | scalar_t dot01 = v0.dot(v1); //v0.x * v1.x + v0.y * v1.y //np.dot(v0.T, v1)
93 | scalar_t dot02 = v0.dot(v2); //v0.x * v2.x + v0.y * v2.y //np.dot(v0.T, v2)
94 | scalar_t dot11 = v1.dot(v1); //v1.x * v1.x + v1.y * v1.y //np.dot(v1.T, v1)
95 | scalar_t dot12 = v1.dot(v2); //v1.x * v2.x + v1.y * v2.y//np.dot(v1.T, v2)
96 |
97 | // barycentric coordinates
98 | scalar_t inverDeno;
99 | if(dot00*dot11 - dot01*dot01 == 0)
100 | inverDeno = 0;
101 | else
102 | inverDeno = 1/(dot00*dot11 - dot01*dot01);
103 |
104 | scalar_t u = (dot11*dot02 - dot01*dot12)*inverDeno;
105 | scalar_t v = (dot00*dot12 - dot01*dot02)*inverDeno;
106 |
107 | // weight
108 | w[0] = 1 - u - v;
109 | w[1] = v;
110 | w[2] = u;
111 | }
112 |
113 | // Ref: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/overview-rasterization-algorithm
114 | template <typename scalar_t>
115 | __global__ void forward_rasterize_cuda_kernel(
116 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3]
117 | scalar_t* depth_buffer,
118 | int* triangle_buffer,
119 | scalar_t* baryw_buffer,
120 | int batch_size, int h, int w,
121 | int ntri) {
122 |
123 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
124 | if (i >= batch_size * ntri) {
125 | return;
126 | }
127 | int bn = i/ntri;
128 | const scalar_t* face = &face_vertices[i * 9];
129 | scalar_t bw[3];
130 |     point<scalar_t> p0, p1, p2, p;
131 |
132 | p0.x = face[0]; p0.y=face[1];
133 | p1.x = face[3]; p1.y=face[4];
134 | p2.x = face[6]; p2.y=face[7];
135 |
136 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
137 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
138 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
139 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
140 |
141 | for(int y = y_min; y <= y_max; y++) //h
142 | {
143 | for(int x = x_min; x <= x_max; x++) //w
144 | {
145 | p.x = x; p.y = y;
146 | barycentric_weight(bw, p, p0, p1, p2);
147 | // if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face))
148 | if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0))
149 | {
150 | // perspective correct: https://www.scratchapixel.com/lessons/3d-basic-rendering/rasterization-practical-implementation/perspective-correct-interpolation-vertex-attributes
151 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]);
152 | // printf("%f %f %f \n", (float)zp, (float)face[2], (float)bw[2]);
153 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp);
154 | if(depth_buffer[bn*h*w + y*w + x] == zp)
155 | {
156 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri);
157 | for(int k=0; k<3; k++){
158 | baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k];
159 | }
160 | }
161 | }
162 | }
163 | }
164 |
165 | }
166 |
167 | template <typename scalar_t>
168 | __global__ void forward_rasterize_colors_cuda_kernel(
169 | const scalar_t* __restrict__ face_vertices, //[bz, nf, 3, 3]
170 | const scalar_t* __restrict__ face_colors, //[bz, nf, 3, 3]
171 | scalar_t* depth_buffer,
172 | int* triangle_buffer,
173 | scalar_t* images,
174 | int batch_size, int h, int w,
175 | int ntri) {
176 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
177 | if (i >= batch_size * ntri) {
178 | return;
179 | }
180 | int bn = i/ntri;
181 | const scalar_t* face = &face_vertices[i * 9];
182 | const scalar_t* color = &face_colors[i * 9];
183 | scalar_t bw[3];
184 |     point<scalar_t> p0, p1, p2, p;
185 |
186 | p0.x = face[0]; p0.y=face[1];
187 | p1.x = face[3]; p1.y=face[4];
188 | p2.x = face[6]; p2.y=face[7];
189 | scalar_t cl[3][3];
190 | for (int num = 0; num < 3; num++) {
191 | for (int dim = 0; dim < 3; dim++) {
192 | cl[num][dim] = color[3 * num + dim]; //[3p,3rgb]
193 | }
194 | }
195 | int x_min = max((int)ceil(min(p0.x, min(p1.x, p2.x))), 0);
196 | int x_max = min((int)floor(max(p0.x, max(p1.x, p2.x))), w - 1);
197 | int y_min = max((int)ceil(min(p0.y, min(p1.y, p2.y))), 0);
198 | int y_max = min((int)floor(max(p0.y, max(p1.y, p2.y))), h - 1);
199 |
200 | for(int y = y_min; y <= y_max; y++) //h
201 | {
202 | for(int x = x_min; x <= x_max; x++) //w
203 | {
204 | p.x = x; p.y = y;
205 | barycentric_weight(bw, p, p0, p1, p2);
206 | if(((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0)) && check_face_frontside(face))
207 | // if((bw[2] >= 0) && (bw[1] >= 0) && (bw[0]>0))
208 | {
209 | scalar_t zp = 1. / (bw[0] / face[2] + bw[1] / face[5] + bw[2] / face[8]);
210 |
211 | atomicMin(&depth_buffer[bn*h*w + y*w + x], zp);
212 | if(depth_buffer[bn*h*w + y*w + x] == zp)
213 | {
214 | triangle_buffer[bn*h*w + y*w + x] = (int)(i%ntri);
215 | for(int k=0; k<3; k++){
216 | // baryw_buffer[bn*h*w*3 + y*w*3 + x*3 + k] = bw[k];
217 | images[bn*h*w*3 + y*w*3 + x*3 + k] = bw[0]*cl[0][k] + bw[1]*cl[1][k] + bw[2]*cl[2][k];
218 | }
219 | // buffers[bn*h*w*2 + y*w*2 + x*2 + 1] = p_depth;
220 | }
221 | }
222 | }
223 | }
224 |
225 | }
226 |
227 | }
228 |
229 | std::vector<at::Tensor> forward_rasterize_cuda(
230 | at::Tensor face_vertices,
231 | at::Tensor depth_buffer,
232 | at::Tensor triangle_buffer,
233 | at::Tensor baryw_buffer,
234 | int h,
235 | int w){
236 |
237 | const auto batch_size = face_vertices.size(0);
238 | const auto ntri = face_vertices.size(1);
239 |
240 | // print(channel_size)
241 | const int threads = 512;
242 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1);
243 |
244 |     AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda1", ([&] {
245 |       forward_rasterize_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
246 |         face_vertices.data<scalar_t>(),
247 |         depth_buffer.data<scalar_t>(),
248 |         triangle_buffer.data<int>(),
249 |         baryw_buffer.data<scalar_t>(),
250 |         batch_size, h, w,
251 |         ntri);
252 |     }));
253 |
254 |     // better to do it twice (or there will be black spots in the rendering)
255 |     AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_cuda2", ([&] {
256 |       forward_rasterize_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
257 |         face_vertices.data<scalar_t>(),
258 |         depth_buffer.data<scalar_t>(),
259 |         triangle_buffer.data<int>(),
260 |         baryw_buffer.data<scalar_t>(),
261 |         batch_size, h, w,
262 |         ntri);
263 |     }));
264 | cudaError_t err = cudaGetLastError();
265 | if (err != cudaSuccess)
266 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err));
267 |
268 | return {depth_buffer, triangle_buffer, baryw_buffer};
269 | }
270 |
271 |
272 | std::vector<at::Tensor> forward_rasterize_colors_cuda(
273 | at::Tensor face_vertices,
274 | at::Tensor face_colors,
275 | at::Tensor depth_buffer,
276 | at::Tensor triangle_buffer,
277 | at::Tensor images,
278 | int h,
279 | int w){
280 |
281 | const auto batch_size = face_vertices.size(0);
282 | const auto ntri = face_vertices.size(1);
283 |
284 | // print(channel_size)
285 | const int threads = 512;
286 | const dim3 blocks_1 ((batch_size * ntri - 1) / threads +1);
287 | //initial
288 |
289 |     AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] {
290 |       forward_rasterize_colors_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
291 |         face_vertices.data<scalar_t>(),
292 |         face_colors.data<scalar_t>(),
293 |         depth_buffer.data<scalar_t>(),
294 |         triangle_buffer.data<int>(),
295 |         images.data<scalar_t>(),
296 |         batch_size, h, w,
297 |         ntri);
298 |     }));
299 | // better to do it twice
300 | // AT_DISPATCH_FLOATING_TYPES(face_vertices.type(), "forward_rasterize_colors_cuda", ([&] {
301 | // forward_rasterize_colors_cuda_kernel<<>>(
302 | // face_vertices.data(),
303 | // face_colors.data(),
304 | // depth_buffer.data(),
305 | // triangle_buffer.data(),
306 | // images.data(),
307 | // batch_size, h, w,
308 | // ntri);
309 | // }));
310 | cudaError_t err = cudaGetLastError();
311 | if (err != cudaSuccess)
312 | printf("Error in forward_rasterize_cuda_kernel: %s\n", cudaGetErrorString(err));
313 |
314 | return {depth_buffer, triangle_buffer, images};
315 | }
316 |
317 |
318 |
319 |
320 |
--------------------------------------------------------------------------------
/decalib/utils/tensor_cropper.py:
--------------------------------------------------------------------------------
1 | '''
2 | Crop
3 | for torch tensors
4 | Given an image and a bbox (center, bbox size),
5 | return: the cropped image and the tform used to transform keypoints accordingly.
6 | Only square crops are supported.
7 | '''
8 | import torch
9 | from kornia.geometry.transform.imgwarp import (
10 | warp_perspective, get_perspective_transform, warp_affine
11 | )
12 |
13 | def points2bbox(points, points_scale=None):
14 | if points_scale:
15 | assert points_scale[0]==points_scale[1]
16 | points = points.clone()
17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0]
18 | min_coords, _ = torch.min(points, dim=1)
19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1]
20 | max_coords, _ = torch.max(points, dim=1)
21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1]
22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
23 |
24 | width = (xmax - xmin)
25 | height = (ymax - ymin)
26 | # Convert the bounding box to a square box
27 | size = torch.max(width, height).unsqueeze(-1)
28 | return center, size
29 |
30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.):
31 | batch_size = center.shape[0]
32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale
33 | center = center + trans_scale*bbox_size # 0.5
34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0]
35 | size = bbox_size*scale
36 | return center, size
37 |
38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False):
39 | ''' for batch image
40 | Args:
41 |         image (torch.Tensor): the input tensor of shape (bz, C, H, W).
42 |         center: [bz, 2]
43 |         bbox_size: [bz, 1]
44 |         crop_size (int): side length of the square output crop.
45 |         interpolation (str): interpolation flag. Default: 'bilinear'.
46 |         align_corners (bool): mode for grid generation. Default: False. See
47 |         https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
48 | Returns:
49 | cropped_image
50 | tform
51 | '''
52 | dtype = image.dtype
53 | device = image.device
54 | batch_size = image.shape[0]
55 | # points: top-left, top-right, bottom-right, bottom-left
56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous()
57 |
58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1)
59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
61 | src_pts[:, 2, :] = center + bbox_size * 0.5
62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5
64 |
65 | DST_PTS = torch.tensor([[
66 | [0, 0],
67 | [crop_size - 1, 0],
68 | [crop_size - 1, crop_size - 1],
69 | [0, crop_size - 1],
70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1)
71 | # estimate transformation between points
72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
73 | # simulate broadcasting
74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)
75 |
76 | # warp images
77 | cropped_image = warp_affine(
78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size),
79 | flags=interpolation, align_corners=align_corners)
80 |
81 | tform = torch.transpose(dst_trans_src, 2, 1)
82 | # tform = torch.inverse(dst_trans_src)
83 | return cropped_image, tform
84 |
85 | class Cropper(object):
86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.):
87 | self.crop_size = crop_size
88 | self.scale = scale
89 | self.trans_scale = trans_scale
90 |
91 | def crop(self, image, points, points_scale=None):
92 | # points to bbox
93 | center, bbox_size = points2bbox(points.clone(), points_scale)
94 |         # augment bbox. TODO: add rotation?
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale)
96 | # crop
97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size)
98 | return cropped_image, tform
99 |
100 | def transform_points(self, points, tform, points_scale=None, normalize = True):
101 | points_2d = points[:,:,:2]
102 |
103 | #'input points must use original range'
104 | if points_scale:
105 | assert points_scale[0]==points_scale[1]
106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
107 |
108 | batch_size, n_points, _ = points.shape
109 | trans_points_2d = torch.bmm(
110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
111 | tform
112 | )
113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
114 | if normalize:
115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1
116 | return trans_points
117 |
118 | def transform_points(points, tform, points_scale=None, out_scale=None):
119 | points_2d = points[:,:,:2]
120 |
121 | #'input points must use original range'
122 | if points_scale:
123 | assert points_scale[0]==points_scale[1]
124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
125 | # import ipdb; ipdb.set_trace()
126 |
127 | batch_size, n_points, _ = points.shape
128 | trans_points_2d = torch.bmm(
129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
130 | tform
131 | )
132 | if out_scale: # h,w of output image size
133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1
134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1
135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
136 | return trans_points
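137 |
138 | # Usage sketch (illustrative, not original code): crop square regions around landmarks from a
139 | # batch of images, then map the landmarks into the crop. points are (B, N, 3) with x, y in
140 | # [-1, 1] when points_scale = (H, W) is given.
141 | if __name__ == '__main__':
142 |     images = torch.rand(2, 3, 256, 256)
143 |     points = torch.rand(2, 68, 3) * 2 - 1
144 |     cropper = Cropper(crop_size=224, scale=[1.4, 1.8])
145 |     cropped, tform = cropper.crop(images, points, points_scale=(256, 256))
146 |     cropped_points = cropper.transform_points(points, tform, points_scale=(256, 256))
147 |     print(cropped.shape, cropped_points.shape)   # (2, 3, 224, 224) (2, 68, 3)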
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | from datetime import datetime
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from PIL import Image
11 | from diffusers import AutoencoderKL, DDIMScheduler
12 | from diffusers.utils.import_utils import is_xformers_available
13 | from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from transformers import CLIPTokenizer, CLIPTextModel
16 |
17 | from models.guidance_encoder import GuidanceEncoder
18 | from models.mgportrait_model import MgpModel
19 | from models.mutual_self_attention import ReferenceAttentionControl
20 | from models.unet_2d_condition import UNet2DConditionModel
21 | from models.unet_3d import UNet3DConditionModel
22 | from pipelines.pipeline_aggregation import MultiGuidance2LongVideoPipeline
23 | from utils.video_utils import resize_tensor_frames, save_videos_grid, pil_list_to_tensor
24 |
25 |
26 | def tokenize_captions(tokenizer, captions):
27 | inputs = tokenizer(
28 | captions,
29 | max_length=tokenizer.model_max_length,
30 | padding="max_length",
31 | truncation=True,
32 | return_tensors="pt"
33 | )
34 | return inputs.input_ids
35 |
36 |
37 | def setup_savedir(cfg):
38 | time_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
39 | if cfg.exp_name is None:
40 | savedir = f"results/animations/exp-{time_str}"
41 | else:
42 | savedir = f"results/animations/{cfg.exp_name}-{time_str}"
43 |
44 | os.makedirs(savedir, exist_ok=True)
45 |
46 | return savedir
47 |
48 |
49 | def setup_guidance_encoder(cfg):
50 | guidance_encoder_group = dict()
51 |
52 | if cfg.weight_dtype == "fp16":
53 | weight_dtype = torch.float16
54 | else:
55 | weight_dtype = torch.float32
56 |     weight_dtype = torch.float32  # note: this unconditional override makes the fp16 branch above a no-op
57 | # ['depth', 'normal', 'dwpose']
58 | for guidance_type in cfg.guidance_types:
59 | guidance_encoder_group[guidance_type] = GuidanceEncoder(
60 | guidance_embedding_channels=cfg.guidance_encoder_kwargs.guidance_embedding_channels,
61 | guidance_input_channels=cfg.guidance_encoder_kwargs.guidance_input_channels,
62 | block_out_channels=cfg.guidance_encoder_kwargs.block_out_channels,
63 | ).to(device="cuda", dtype=weight_dtype)
64 |
65 | return guidance_encoder_group
66 |
67 |
68 |
69 | def combine_guidance_data(cfg):
70 | guidance_types = cfg.guidance_types
71 | guidance_data_folder = cfg.data.guidance_data_folder
72 |
73 | guidance_pil_group = dict()
74 | for guidance_type in guidance_types:
75 | guidance_pil_group[guidance_type] = []
76 | guidance_image_lst = sorted(
77 | Path(osp.join(guidance_data_folder, guidance_type)).iterdir()
78 | )
79 | guidance_image_lst = (
80 | guidance_image_lst
81 | if not cfg.data.frame_range
82 | else guidance_image_lst[cfg.data.frame_range[0]:cfg.data.frame_range[1]]
83 | )
84 |
85 | for guidance_image_path in guidance_image_lst:
86 | guidance_pil_group[guidance_type] += [
87 | Image.open(guidance_image_path).convert("RGB")
88 | ]
89 |
90 | first_guidance_length = len(list(guidance_pil_group.values())[0])
91 | assert all(
92 | len(sublist) == first_guidance_length
93 | for sublist in list(guidance_pil_group.values())
94 | )
95 |
96 | return guidance_pil_group, first_guidance_length
97 |
98 |
99 |
100 | def inference(
101 | cfg,
102 | vae,
103 | text_encoder,
104 | tokenizer,
105 | model,
106 | scheduler,
107 | ref_image_pil,
108 | guidance_pil_group,
109 | video_length,
110 | width,
111 | height,
112 | device="cuda",
113 | dtype=torch.float16,
114 | ):
115 | reference_unet = model.reference_unet
116 | denoising_unet = model.denoising_unet
117 | guidance_types = cfg.guidance_types
118 | guidance_encoder_group = {
119 | f"guidance_encoder_{g}": getattr(model, f"guidance_encoder_{g}")
120 | for g in guidance_types
121 | }
122 |
123 | generator = torch.Generator(device=device)
124 | generator.manual_seed(cfg.seed)
125 | pipeline = MultiGuidance2LongVideoPipeline(
126 | vae=vae,
127 | reference_unet=reference_unet,
128 | denoising_unet=denoising_unet,
129 | **guidance_encoder_group,
130 | scheduler=scheduler,
131 | guidance_process_size=cfg.data.get("guidance_process_size", None)
132 | )
133 | pipeline = pipeline.to(device, dtype)
134 |
135 | prompt = "A close up of a person."
136 | prompt_embeds = text_encoder(tokenize_captions(tokenizer, [prompt]).to('cuda'))[0]
137 |
138 | video = pipeline(
139 | ref_image_pil,
140 | prompt_embeds,
141 | guidance_pil_group,
142 | width,
143 | height,
144 | video_length,
145 | num_inference_steps=cfg.num_inference_steps,
146 | guidance_scale=cfg.guidance_scale,
147 | generator=generator,
148 | ).videos
149 |
150 | del pipeline
151 | torch.cuda.empty_cache()
152 |
153 | return video
154 |
155 |
156 | def main(cfg):
157 | logging.basicConfig(
158 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
159 | datefmt="%m/%d/%Y %H:%M:%S",
160 | level=logging.INFO,
161 | )
162 |
163 | save_dir = setup_savedir(cfg)
164 | logging.info(f"Running inference ...")
165 | weight_dtype = torch.float32
166 |
167 | sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
168 | # {'num_train_timesteps': 1000, 'beta_start': 0.00085, 'beta_end': 0.012,
169 | # 'beta_schedule': 'linear', 'steps_offset': 1, 'clip_sample': False}
170 | if cfg.enable_zero_snr:
171 | sched_kwargs.update(
172 | rescale_betas_zero_snr=True,
173 | timestep_spacing="trailing",
174 | prediction_type="v_prediction",
175 | )
176 | noise_scheduler = DDIMScheduler(**sched_kwargs)
177 |     sched_kwargs.update({"beta_schedule": "scaled_linear"})  # no effect here: the DDIMScheduler above is already constructed
178 |
179 | vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
180 | dtype=weight_dtype, device="cuda"
181 | )
182 |
183 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
184 | cfg.base_model_path,
185 | cfg.motion_module_path,
186 | subfolder="unet",
187 | unet_additional_kwargs=cfg.unet_additional_kwargs,
188 | ).to(dtype=weight_dtype, device="cuda")
189 |
190 | reference_unet = UNet2DConditionModel.from_pretrained(
191 | cfg.base_model_path,
192 | subfolder="unet",
193 | ).to(device="cuda", dtype=weight_dtype)
194 |
195 | text_encoder = CLIPTextModel.from_pretrained(
196 | cfg.base_model_path,
197 | subfolder="text_encoder",
198 | ).to(device="cuda")
199 |
200 | tokenizer = CLIPTokenizer.from_pretrained(
201 | cfg.base_model_path,
202 | subfolder="tokenizer",
203 | )
204 |
205 | guidance_encoder_group = setup_guidance_encoder(cfg)
206 |
207 |
208 | ckpt_dir = cfg.ckpt_dir
209 | denoising_unet.load_state_dict(
210 | torch.load(
211 | osp.join(ckpt_dir, f"denoising_unet.pth"),
212 | map_location="cpu",
213 | ),
214 | strict=False,
215 | )
216 | reference_unet.load_state_dict(
217 | torch.load(
218 | osp.join(ckpt_dir, f"reference_unet.pth"),
219 | map_location="cpu",
220 | ),
221 | strict=False,
222 | )
223 |
224 | for guidance_type, guidance_encoder_module in guidance_encoder_group.items():
225 | guidance_encoder_module.load_state_dict(
226 | torch.load(
227 | osp.join(ckpt_dir, f"guidance_encoder_{guidance_type}.pth"),
228 | map_location="cpu",
229 | ),
230 | strict=False,
231 | )
232 |
233 | reference_control_writer = ReferenceAttentionControl(
234 | reference_unet,
235 | do_classifier_free_guidance=False,
236 | mode="write",
237 | fusion_blocks="full",
238 | )
239 | reference_control_reader = ReferenceAttentionControl(
240 | denoising_unet,
241 | do_classifier_free_guidance=False,
242 | mode="read",
243 | fusion_blocks="full",
244 | )
245 |
246 | model = MgpModel(
247 | reference_unet=reference_unet,
248 | denoising_unet=denoising_unet,
249 | reference_control_writer=reference_control_writer,
250 | reference_control_reader=reference_control_reader,
251 | guidance_encoder_group=guidance_encoder_group,
252 | ).to("cuda", dtype=weight_dtype)
253 |
254 | if cfg.enable_xformers_memory_efficient_attention:
255 | if is_xformers_available():
256 | reference_unet.enable_xformers_memory_efficient_attention()
257 | denoising_unet.enable_xformers_memory_efficient_attention()
258 | else:
259 | raise ValueError(
260 | "xformers is not available. Make sure it is installed correctly"
261 | )
262 |
263 | ref_image_path = cfg.data.ref_image_path
264 | ref_image_pil = Image.open(ref_image_path)
265 | ref_image_w, ref_image_h = ref_image_pil.size
266 |
267 |
268 | guidance_pil_group, video_length = combine_guidance_data(cfg)
269 |
270 | result_video_tensor = inference(
271 | cfg=cfg,
272 | vae=vae,
273 | text_encoder=text_encoder,
274 | tokenizer=tokenizer,
275 | model=model,
276 | scheduler=noise_scheduler,
277 | ref_image_pil=ref_image_pil,
278 | guidance_pil_group=guidance_pil_group,
279 | video_length=video_length,
280 | width=cfg.width,
281 | height=cfg.height,
282 | device="cuda",
283 | dtype=weight_dtype,
284 | ) # (1, c, f, h, w)
285 |
286 | result_video_tensor = resize_tensor_frames(
287 | result_video_tensor, (ref_image_h, ref_image_w)
288 | )
289 | save_videos_grid(result_video_tensor, osp.join(save_dir, "animation.mp4"))
290 |
291 | ref_video_tensor = transforms.ToTensor()(ref_image_pil)[None, :, None, ...].repeat(
292 | 1, 1, video_length, 1, 1
293 | )
294 | guidance_video_tensor_lst = []
295 | for guidance_pil_lst in guidance_pil_group.values():
296 | guidance_video_tensor_lst += [
297 | pil_list_to_tensor(guidance_pil_lst, size=(ref_image_h, ref_image_w))
298 | ]
299 | guidance_video_tensor = torch.stack(guidance_video_tensor_lst, dim=0)
300 |
301 | grid_video = torch.cat([ref_video_tensor, result_video_tensor], dim=0)
302 | grid_video_wguidance = torch.cat(
303 | [ref_video_tensor, result_video_tensor, guidance_video_tensor], dim=0
304 | )
305 |
306 | save_videos_grid(grid_video, osp.join(save_dir, "grid.mp4"))
307 | save_videos_grid(grid_video_wguidance, osp.join(save_dir, "guidance.mp4"))
308 |
309 | logging.info(f"Inference completed, results saved in {save_dir}")
310 |
311 |
312 | if __name__ == "__main__":
313 |
314 | parser = argparse.ArgumentParser()
315 | parser.add_argument("--config", type=str, default="./configs/inference/inference.yaml")
316 | args = parser.parse_args()
317 |
318 | if args.config[-5:] == ".yaml":
319 | cfg = OmegaConf.load(args.config)
320 | else:
321 |         raise ValueError("Unsupported config file format; expected a .yaml file")
322 |
323 | main(cfg)
324 |
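325 |
326 | # Example invocation (paths are the repository defaults; adjust as needed):
327 | #   python inference.py --config ./configs/inference/inference.yaml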
--------------------------------------------------------------------------------
/models/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/exp_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/exp_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/guidance_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/guidance_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mgportrait_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/mgportrait_model.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/motion_module.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/motion_module.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mutual_self_attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/mutual_self_attention.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/resnet.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/transformer_2d.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/transformer_2d.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/transformer_3d.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/transformer_3d.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_2d_blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_2d_blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_2d_condition.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_2d_condition.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_3d.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_3d_blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/models/__pycache__/unet_3d_blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/models/exp_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 |
7 | from diffusers.models.modeling_utils import ModelMixin
8 | from diffusers.utils import BaseOutput
9 | from dataclasses import dataclass
10 |
11 |
12 |
13 |
14 | class ExpEncoder(ModelMixin):
15 | def __init__(
16 | self,
17 | input_size: int = 64,
18 | hidden_sizes=None,
19 | output_size: int = 768,
20 |
21 | ):
22 | super().__init__()
23 | if hidden_sizes is None:
24 | hidden_sizes = [128, 256, 512]
25 | self.layers = nn.ModuleList()
26 | self.layers.append(nn.Linear(input_size, hidden_sizes[0]))
27 | for i in range(1, len(hidden_sizes)):
28 | self.layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
29 | self.layers.append(nn.Linear(hidden_sizes[-1], output_size))
30 |
31 | def forward(self, x):
32 | for layer in self.layers[:-1]:
33 | x = F.relu(layer(x))
34 |         # no activation on the output layer
35 | x = self.layers[-1](x)
36 |
37 | return x
38 |
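39 | # Shape sketch (illustrative): with the defaults, per-frame 64-d expression codes are mapped
40 | # to 768-d embeddings, i.e. (B, 64) -> (B, 768).
41 | if __name__ == "__main__":
42 |     enc = ExpEncoder()
43 |     print(enc(torch.randn(4, 64)).shape)   # torch.Size([4, 768])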
--------------------------------------------------------------------------------
/models/guidance_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 |
7 | from diffusers.models.modeling_utils import ModelMixin
8 | from diffusers.utils import BaseOutput
9 | from dataclasses import dataclass
10 | from transformers import CLIPVisionModelWithProjection
11 |
12 | from models.motion_module import zero_module
13 | from models.resnet import InflatedConv3d, InflatedGroupNorm
14 | from models.attention import TemporalBasicTransformerBlock
15 | from models.transformer_3d import Transformer3DModel
16 |
17 |
18 | class GuidanceEncoder(ModelMixin):
19 | def __init__(
20 | self,
21 | guidance_embedding_channels: int,
22 | guidance_input_channels: int = 3,
23 | block_out_channels: Tuple[int] = (16, 32, 96, 256),
24 | attention_num_heads: int = 8,
25 | ):
26 | super().__init__()
27 | self.guidance_input_channels = guidance_input_channels
28 | self.conv_in = InflatedConv3d(
29 | guidance_input_channels, block_out_channels[0], kernel_size=3, padding=1
30 | )
31 |
32 | self.blocks = nn.ModuleList([])
33 | self.attentions = nn.ModuleList([])
34 |
35 | for i in range(len(block_out_channels) - 1):
36 | channel_in = block_out_channels[i]
37 | channel_out = block_out_channels[i + 1]
38 |
39 | self.blocks.append(
40 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
41 | )
42 | self.attentions.append(
43 | Transformer3DModel(
44 | attention_num_heads,
45 | channel_in // attention_num_heads,
46 | channel_in,
47 | norm_num_groups=1,
48 | unet_use_cross_frame_attention=False,
49 | unet_use_temporal_attention=False,
50 | )
51 | )
52 |
53 | self.blocks.append(
54 | InflatedConv3d(
55 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
56 | )
57 | )
58 | self.attentions.append(
59 | Transformer3DModel(
60 | attention_num_heads,
61 | channel_out // attention_num_heads,
62 | channel_out,
63 | norm_num_groups=32,
64 | unet_use_cross_frame_attention=False,
65 | unet_use_temporal_attention=False,
66 | )
67 | )
68 |
69 | attention_channel_out = block_out_channels[-1]
70 | self.guidance_attention = Transformer3DModel(
71 | attention_num_heads,
72 | attention_channel_out // attention_num_heads,
73 | attention_channel_out,
74 | norm_num_groups=32,
75 | unet_use_cross_frame_attention=False,
76 | unet_use_temporal_attention=False,
77 | )
78 |
79 | self.conv_out = zero_module(
80 | InflatedConv3d(
81 | block_out_channels[-1],
82 | guidance_embedding_channels,
83 | kernel_size=3,
84 | padding=1,
85 | )
86 | )
87 |
88 | def forward(self, condition):
89 |
90 | embedding = self.conv_in(condition)
91 | embedding = F.silu(embedding)
92 |
93 | for block in self.blocks:
94 | embedding = block(embedding)
95 | embedding = F.silu(embedding)
96 |
97 | embedding = self.attentions[-1](embedding).sample
98 | embedding = self.conv_out(embedding)
99 | return embedding
100 |
101 |
102 |
103 |
104 |
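105 | # Shape sketch (values are illustrative assumptions): a 3-channel guidance clip of shape
106 | # (B, 3, F, H, W) is encoded to (B, guidance_embedding_channels, F, H/8, W/8), since the
107 | # default block_out_channels imply three stride-2 convolutions.
108 | if __name__ == "__main__":
109 |     enc = GuidanceEncoder(guidance_embedding_channels=320)
110 |     print(enc(torch.rand(1, 3, 4, 64, 64)).shape)   # expected: torch.Size([1, 320, 4, 8, 8])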
--------------------------------------------------------------------------------
/models/mgportrait_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.unet_2d_condition import UNet2DConditionModel
4 | from models.unet_3d import UNet3DConditionModel
5 | from models.exp_encoder import ExpEncoder
6 | from einops import rearrange
7 |
8 | class MgpModel(nn.Module):
9 | def __init__(
10 | self,
11 | reference_unet: UNet2DConditionModel,
12 | denoising_unet: UNet3DConditionModel,
13 | reference_control_writer,
14 | reference_control_reader,
15 | guidance_encoder_group,
16 | ):
17 | super().__init__()
18 | self.reference_unet = reference_unet
19 | self.denoising_unet = denoising_unet
20 |
21 | self.reference_control_writer = reference_control_writer
22 | self.reference_control_reader = reference_control_reader
23 |
24 | self.guidance_types = []
25 | self.guidance_input_channels = []
26 |
27 | for guidance_type, guidance_module in guidance_encoder_group.items():
28 | setattr(self, f"guidance_encoder_{guidance_type}", guidance_module)
29 | self.guidance_types.append(guidance_type)
30 | self.guidance_input_channels.append(guidance_module.guidance_input_channels)
31 |
32 | def forward(
33 | self,
34 | noisy_latents,
35 | timesteps,
36 | ref_image_latents,
37 | multi_guidance_cond,
38 | text_embeds,
39 | uncond_fwd: bool = False,
40 | ):
41 | guidance_cond_group = torch.split(
42 | multi_guidance_cond, self.guidance_input_channels, dim=1
43 | )
44 | guidance_fea_lst = []
45 | for guidance_idx, guidance_cond in enumerate(guidance_cond_group):
46 | guidance_encoder = getattr(
47 | self, f"guidance_encoder_{self.guidance_types[guidance_idx]}"
48 | )
49 | guidance_fea = guidance_encoder(guidance_cond)
50 | guidance_fea_lst += [guidance_fea]
51 | guidance_fea = torch.stack(guidance_fea_lst, dim=0).sum(0)
52 |
53 | # video_length = exp_embedding.shape[1]
54 | # exp_embedding = rearrange(exp_embedding, "b f d -> (b f) d")
55 | # exp_embed = self.exp_encoder(exp_embedding)
56 | # exp_embed = rearrange(exp_embed, "(b f) d -> b f d", f=video_length)
57 |
58 |
59 | if not uncond_fwd:
60 | ref_timesteps = torch.zeros_like(timesteps)
61 | self.reference_unet(
62 | ref_image_latents,
63 | ref_timesteps,
64 | encoder_hidden_states=text_embeds, # [4, 77, 768]
65 | return_dict=False,
66 | )
67 |
68 | self.reference_control_reader.update(self.reference_control_writer)
69 |
70 | model_pred = self.denoising_unet(
71 | noisy_latents,
72 | timesteps,
73 | guidance_fea=guidance_fea,
74 | encoder_hidden_states=text_embeds, # [btz, frames, dim]
75 | ).sample
76 |
77 | return model_pred
78 |
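79 | # Channel layout sketch (illustrative): with two guidance encoders whose guidance_input_channels
80 | # are [3, 3], multi_guidance_cond is (B, 6, F, H, W); forward() splits it channel-wise, encodes
81 | # each 3-channel map with its own encoder, and sums the per-type features.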
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 |
9 | class InflatedConv3d(nn.Conv2d):
10 | def forward(self, x):
11 | video_length = x.shape[2]
12 |
13 | x = rearrange(x, "b c f h w -> (b f) c h w")
14 | x = super().forward(x)
15 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16 |
17 | return x
18 |
19 |
20 | class InflatedGroupNorm(nn.GroupNorm):
21 | def forward(self, x):
22 | video_length = x.shape[2]
23 |
24 | x = rearrange(x, "b c f h w -> (b f) c h w")
25 | x = super().forward(x)
26 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27 |
28 | return x
29 |
30 |
31 | class Upsample3D(nn.Module):
32 | def __init__(
33 | self,
34 | channels,
35 | use_conv=False,
36 | use_conv_transpose=False,
37 | out_channels=None,
38 | name="conv",
39 | ):
40 | super().__init__()
41 | self.channels = channels
42 | self.out_channels = out_channels or channels
43 | self.use_conv = use_conv
44 | self.use_conv_transpose = use_conv_transpose
45 | self.name = name
46 |
47 | conv = None
48 | if use_conv_transpose:
49 | raise NotImplementedError
50 | elif use_conv:
51 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52 |
53 | def forward(self, hidden_states, output_size=None):
54 | assert hidden_states.shape[1] == self.channels
55 |
56 | if self.use_conv_transpose:
57 | raise NotImplementedError
58 |
59 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60 | dtype = hidden_states.dtype
61 | if dtype == torch.bfloat16:
62 | hidden_states = hidden_states.to(torch.float32)
63 |
64 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65 | if hidden_states.shape[0] >= 64:
66 | hidden_states = hidden_states.contiguous()
67 |
68 | # if `output_size` is passed we force the interpolation output
69 | # size and do not make use of `scale_factor=2`
70 | if output_size is None:
71 | hidden_states = F.interpolate(
72 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73 | )
74 | else:
75 | hidden_states = F.interpolate(
76 | hidden_states, size=output_size, mode="nearest"
77 | )
78 |
79 | # If the input is bfloat16, we cast back to bfloat16
80 | if dtype == torch.bfloat16:
81 | hidden_states = hidden_states.to(dtype)
82 |
83 | # if self.use_conv:
84 | # if self.name == "conv":
85 | # hidden_states = self.conv(hidden_states)
86 | # else:
87 | # hidden_states = self.Conv2d_0(hidden_states)
88 | hidden_states = self.conv(hidden_states)
89 |
90 | return hidden_states
91 |
92 |
93 | class Downsample3D(nn.Module):
94 | def __init__(
95 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96 | ):
97 | super().__init__()
98 | self.channels = channels
99 | self.out_channels = out_channels or channels
100 | self.use_conv = use_conv
101 | self.padding = padding
102 | stride = 2
103 | self.name = name
104 |
105 | if use_conv:
106 | self.conv = InflatedConv3d(
107 | self.channels, self.out_channels, 3, stride=stride, padding=padding
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | def forward(self, hidden_states):
113 | assert hidden_states.shape[1] == self.channels
114 | if self.use_conv and self.padding == 0:
115 | raise NotImplementedError
116 |
117 | assert hidden_states.shape[1] == self.channels
118 | hidden_states = self.conv(hidden_states)
119 |
120 | return hidden_states
121 |
122 |
123 | class ResnetBlock3D(nn.Module):
124 | def __init__(
125 | self,
126 | *,
127 | in_channels,
128 | out_channels=None,
129 | conv_shortcut=False,
130 | dropout=0.0,
131 | temb_channels=512,
132 | groups=32,
133 | groups_out=None,
134 | pre_norm=True,
135 | eps=1e-6,
136 | non_linearity="swish",
137 | time_embedding_norm="default",
138 | output_scale_factor=1.0,
139 | use_in_shortcut=None,
140 | use_inflated_groupnorm=None,
141 | ):
142 | super().__init__()
143 | self.pre_norm = pre_norm
144 | self.pre_norm = True
145 | self.in_channels = in_channels
146 | out_channels = in_channels if out_channels is None else out_channels
147 | self.out_channels = out_channels
148 | self.use_conv_shortcut = conv_shortcut
149 | self.time_embedding_norm = time_embedding_norm
150 | self.output_scale_factor = output_scale_factor
151 |
152 | if groups_out is None:
153 | groups_out = groups
154 |
155 | assert use_inflated_groupnorm != None
156 | if use_inflated_groupnorm:
157 | self.norm1 = InflatedGroupNorm(
158 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159 | )
160 | else:
161 | self.norm1 = torch.nn.GroupNorm(
162 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163 | )
164 |
165 | self.conv1 = InflatedConv3d(
166 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
167 | )
168 |
169 | if temb_channels is not None:
170 | if self.time_embedding_norm == "default":
171 | time_emb_proj_out_channels = out_channels
172 | elif self.time_embedding_norm == "scale_shift":
173 | time_emb_proj_out_channels = out_channels * 2
174 | else:
175 | raise ValueError(
176 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
177 | )
178 |
179 | self.time_emb_proj = torch.nn.Linear(
180 | temb_channels, time_emb_proj_out_channels
181 | )
182 | else:
183 | self.time_emb_proj = None
184 |
185 | if use_inflated_groupnorm:
186 | self.norm2 = InflatedGroupNorm(
187 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188 | )
189 | else:
190 | self.norm2 = torch.nn.GroupNorm(
191 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192 | )
193 | self.dropout = torch.nn.Dropout(dropout)
194 | self.conv2 = InflatedConv3d(
195 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
196 | )
197 |
198 | if non_linearity == "swish":
199 | self.nonlinearity = lambda x: F.silu(x)
200 | elif non_linearity == "mish":
201 | self.nonlinearity = Mish()
202 | elif non_linearity == "silu":
203 | self.nonlinearity = nn.SiLU()
204 |
205 | self.use_in_shortcut = (
206 | self.in_channels != self.out_channels
207 | if use_in_shortcut is None
208 | else use_in_shortcut
209 | )
210 |
211 | self.conv_shortcut = None
212 | if self.use_in_shortcut:
213 | self.conv_shortcut = InflatedConv3d(
214 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
215 | )
216 |
217 | def forward(self, input_tensor, temb):
218 | hidden_states = input_tensor
219 |
220 | hidden_states = self.norm1(hidden_states)
221 | hidden_states = self.nonlinearity(hidden_states)
222 |
223 | hidden_states = self.conv1(hidden_states)
224 |
225 | if temb is not None:
226 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227 |
228 | if temb is not None and self.time_embedding_norm == "default":
229 | hidden_states = hidden_states + temb
230 |
231 | hidden_states = self.norm2(hidden_states)
232 |
233 | if temb is not None and self.time_embedding_norm == "scale_shift":
234 | scale, shift = torch.chunk(temb, 2, dim=1)
235 | hidden_states = hidden_states * (1 + scale) + shift
236 |
237 | hidden_states = self.nonlinearity(hidden_states)
238 |
239 | hidden_states = self.dropout(hidden_states)
240 | hidden_states = self.conv2(hidden_states)
241 |
242 | if self.conv_shortcut is not None:
243 | input_tensor = self.conv_shortcut(input_tensor)
244 |
245 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246 |
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
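254 | # Shape sketch (illustrative): InflatedConv3d runs a 2D convolution frame by frame by folding
255 | # the frame axis into the batch, so the frame count of a (B, C, F, H, W) input is preserved.
256 | if __name__ == "__main__":
257 |     conv = InflatedConv3d(3, 16, kernel_size=3, padding=1)
258 |     print(conv(torch.rand(2, 3, 8, 32, 32)).shape)   # torch.Size([2, 16, 8, 32, 32])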
--------------------------------------------------------------------------------
/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models import ModelMixin
7 | from diffusers.utils import BaseOutput
8 | from diffusers.utils.import_utils import is_xformers_available
9 | from einops import rearrange, repeat
10 | from torch import nn
11 |
12 | from .attention import TemporalBasicTransformerBlock
13 |
14 |
15 | @dataclass
16 | class Transformer3DModelOutput(BaseOutput):
17 | sample: torch.FloatTensor
18 |
19 |
20 | if is_xformers_available():
21 | import xformers
22 | import xformers.ops
23 | else:
24 | xformers = None
25 |
26 |
27 | class Transformer3DModel(ModelMixin, ConfigMixin):
28 | _supports_gradient_checkpointing = True
29 |
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(
59 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60 | )
61 | if use_linear_projection:
62 | self.proj_in = nn.Linear(in_channels, inner_dim)
63 | else:
64 | self.proj_in = nn.Conv2d(
65 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66 | )
67 |
68 |         # Define transformer blocks
69 | self.transformer_blocks = nn.ModuleList(
70 | [
71 | TemporalBasicTransformerBlock(
72 | inner_dim,
73 | num_attention_heads,
74 | attention_head_dim,
75 | dropout=dropout,
76 | cross_attention_dim=cross_attention_dim,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | attention_bias=attention_bias,
80 | only_cross_attention=only_cross_attention,
81 | upcast_attention=upcast_attention,
82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83 | unet_use_temporal_attention=unet_use_temporal_attention,
84 | )
85 | for d in range(num_layers)
86 | ]
87 | )
88 |
89 |         # Define output layers
90 | if use_linear_projection:
91 |             self.proj_out = nn.Linear(inner_dim, in_channels)
92 | else:
93 | self.proj_out = nn.Conv2d(
94 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95 | )
96 |
97 | self.gradient_checkpointing = False
98 |
99 | def _set_gradient_checkpointing(self, module, value=False):
100 | if hasattr(module, "gradient_checkpointing"):
101 | module.gradient_checkpointing = value
102 |
103 | def forward(
104 | self,
105 | hidden_states,
106 | encoder_hidden_states=None,
107 | timestep=None,
108 | return_dict: bool = True,
109 | ):
110 | # Input
111 | assert (
112 | hidden_states.dim() == 5
113 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114 | video_length = hidden_states.shape[2]
115 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116 | if encoder_hidden_states is not None:
117 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
118 | encoder_hidden_states = repeat(
119 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
120 | )
121 |
122 | batch, channel, height, weight = hidden_states.shape
123 | residual = hidden_states
124 |
125 | hidden_states = self.norm(hidden_states)
126 | if not self.use_linear_projection:
127 | hidden_states = self.proj_in(hidden_states)
128 | inner_dim = hidden_states.shape[1]
129 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
130 | batch, height * weight, inner_dim
131 | )
132 | else:
133 | inner_dim = hidden_states.shape[1]
134 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
135 | batch, height * weight, inner_dim
136 | )
137 | hidden_states = self.proj_in(hidden_states)
138 |
139 | # Blocks
140 | for i, block in enumerate(self.transformer_blocks):
141 | hidden_states = block(
142 | hidden_states,
143 | encoder_hidden_states=encoder_hidden_states,
144 | timestep=timestep,
145 | video_length=video_length,
146 | )
147 |
148 | # Output
149 | if not self.use_linear_projection:
150 | hidden_states = (
151 | hidden_states.reshape(batch, height, weight, inner_dim)
152 | .permute(0, 3, 1, 2)
153 | .contiguous()
154 | )
155 | hidden_states = self.proj_out(hidden_states)
156 | else:
157 | hidden_states = self.proj_out(hidden_states)
158 | hidden_states = (
159 | hidden_states.reshape(batch, height, weight, inner_dim)
160 | .permute(0, 3, 1, 2)
161 | .contiguous()
162 | )
163 |
164 | output = hidden_states + residual
165 |
166 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
167 | if not return_dict:
168 | return (output,)
169 |
170 | return Transformer3DModelOutput(sample=output)
171 |
--------------------------------------------------------------------------------
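Editor's note: a small sketch of the reshaping convention used by `Transformer3DModel.forward` above (shapes are hypothetical). Frames are folded into the batch axis so the 2D spatial layers treat each frame as an independent image, then unfolded so the temporal blocks can attend across frames.

    import torch
    from einops import rearrange

    latents = torch.randn(1, 4, 8, 32, 32)                  # (b, c, f, h, w)

    # Fold frames into the batch for the 2D projections and spatial attention ...
    flat = rearrange(latents, "b c f h w -> (b f) c h w")   # (8, 4, 32, 32)

    # ... and restore the frame axis afterwards for temporal attention.
    video = rearrange(flat, "(b f) c h w -> b c f h w", f=8)
    assert video.shape == latents.shape
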
/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__init__.py
--------------------------------------------------------------------------------
/pipelines/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/pipelines/__pycache__/context.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/context.cpython-310.pyc
--------------------------------------------------------------------------------
/pipelines/__pycache__/pipe_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/pipe_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/pipelines/__pycache__/pipeline_aggregation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/pipelines/__pycache__/pipeline_aggregation.cpython-310.pyc
--------------------------------------------------------------------------------
/pipelines/context.py:
--------------------------------------------------------------------------------
1 | # TODO: Adapted from cli
2 | from typing import Callable, List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | def ordered_halving(val):
8 | bin_str = f"{val:064b}"
9 | bin_flip = bin_str[::-1]
10 | as_int = int(bin_flip, 2)
11 |
12 | return as_int / (1 << 64)
13 |
14 |
15 | def uniform(
16 | step: int = ...,
17 | num_steps: Optional[int] = None,
18 | num_frames: int = ...,
19 | context_size: Optional[int] = None,
20 | context_stride: int = 3,
21 | context_overlap: int = 4,
22 | closed_loop: bool = True,
23 | ):
24 | if num_frames <= context_size:
25 | yield list(range(num_frames))
26 | return
27 |
28 | context_stride = min(
29 | context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30 | )
31 |
32 | for context_step in 1 << np.arange(context_stride):
33 | pad = int(round(num_frames * ordered_halving(step)))
34 | for j in range(
35 | int(ordered_halving(step) * context_step) + pad,
36 | num_frames + pad + (0 if closed_loop else -context_overlap),
37 | (context_size * context_step - context_overlap),
38 | ):
39 | yield [
40 | e % num_frames
41 | for e in range(j, j + context_size * context_step, context_step)
42 | ]
43 |
44 |
45 | def get_context_scheduler(name: str) -> Callable:
46 | if name == "uniform":
47 | return uniform
48 | else:
49 | raise ValueError(f"Unknown context_overlap policy {name}")
50 |
51 |
52 | def get_total_steps(
53 | scheduler,
54 | timesteps: List[int],
55 | num_steps: Optional[int] = None,
56 | num_frames: int = ...,
57 | context_size: Optional[int] = None,
58 | context_stride: int = 3,
59 | context_overlap: int = 4,
60 | closed_loop: bool = True,
61 | ):
62 | return sum(
63 | len(
64 | list(
65 | scheduler(
66 | i,
67 | num_steps,
68 | num_frames,
69 | context_size,
70 | context_stride,
71 | context_overlap,
72 | )
73 | )
74 | )
75 | for i in range(len(timesteps))
76 | )
77 |
--------------------------------------------------------------------------------
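Editor's note: a usage sketch for the `uniform` context scheduler above; the frame count, window size, and overlap below are illustrative values, not settings taken from the inference config.

    from pipelines.context import get_context_scheduler

    scheduler = get_context_scheduler("uniform")

    # Hypothetical: 24 latent frames denoised in overlapping windows of 16.
    for window in scheduler(
        step=0,
        num_steps=25,
        num_frames=24,
        context_size=16,
        context_stride=1,
        context_overlap=4,
    ):
        print(window)  # frame indices processed together in one pass
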
/pipelines/pipe_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | tensor_interpolation = None
4 |
5 |
6 | def get_tensor_interpolation_method():
7 | return tensor_interpolation
8 |
9 |
10 | def set_tensor_interpolation_method(is_slerp):
11 | global tensor_interpolation
12 | tensor_interpolation = slerp if is_slerp else linear
13 |
14 |
15 | def linear(v1, v2, t):
16 | return (1.0 - t) * v1 + t * v2
17 |
18 |
19 | def slerp(
20 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21 | ) -> torch.Tensor:
22 | u0 = v0 / v0.norm()
23 | u1 = v1 / v1.norm()
24 | dot = (u0 * u1).sum()
25 | if dot.abs() > DOT_THRESHOLD:
26 | # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27 | return (1.0 - t) * v0 + t * v1
28 | omega = dot.acos()
29 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
30 |
--------------------------------------------------------------------------------
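Editor's note: a short sketch of how the interpolation helpers above are meant to be used. Pick slerp or linear blending once, then fetch the chosen callable wherever frames need to be interpolated; the tensors here are random placeholders.

    import torch
    from pipelines.pipe_utils import (
        get_tensor_interpolation_method,
        set_tensor_interpolation_method,
    )

    set_tensor_interpolation_method(True)      # True -> slerp, False -> linear
    interp = get_tensor_interpolation_method()

    v0, v1 = torch.randn(4, 64), torch.randn(4, 64)
    halfway = interp(v0, v1, 0.5)              # slerp falls back to lerp when v0 and v1 are near-parallel
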
/render_and_transfer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch as th
5 | from torchvision.utils import save_image
6 | import argparse
7 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".")))
8 |
9 |
10 | from decalib.deca_with_smirk import DECA
11 | from decalib.utils.config import cfg as deca_cfg
12 | from data_utils.transfer_utils import get_image_dict
13 |
14 | # Build DECA
15 | deca_cfg.model.use_tex = True
16 | deca_cfg.model.tex_path = "./decalib/data/FLAME_texture.npz"
17 | deca_cfg.model.tex_type = "FLAME"
18 | deca = DECA(config=deca_cfg, device="cuda")
19 |
20 |
21 |
22 | def get_render(source, target, save_file):
23 |
24 | src_dict = get_image_dict(source, 512, True)
25 | tar_dict = get_image_dict(target, 512, True)
26 | # ===================get DECA codes of the target image===============================
27 | tar_cropped = tar_dict["image"].unsqueeze(0).to("cuda")
28 | imgname = tar_dict["imagename"]
29 |
30 | with th.no_grad():
31 | tar_code = deca.encode(tar_cropped)
32 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
33 | # ===================get DECA codes of the source image===============================
34 | src_cropped = src_dict["image"].unsqueeze(0).to("cuda")
35 | with th.no_grad():
36 | src_code = deca.encode(src_cropped)
37 | # To align the face when the pose is changing
38 | src_ffhq_center = deca.decode(src_code, return_ffhq_center=True)
39 | tar_ffhq_center = deca.decode(tar_code, return_ffhq_center=True)
40 |
41 | src_tform = src_dict["tform"].unsqueeze(0)
42 | src_tform = th.inverse(src_tform).transpose(1, 2).to("cuda")
43 | src_code["tform"] = src_tform
44 |
45 | tar_tform = tar_dict["tform"].unsqueeze(0)
46 | tar_tform = th.inverse(tar_tform).transpose(1, 2).to("cuda")
47 | tar_code["tform"] = tar_tform
48 |
49 |     src_image = src_dict["original_image"].unsqueeze(0).to("cuda") # averaged parameters
50 | tar_image = tar_dict["original_image"].unsqueeze(0).to("cuda")
51 |
52 | # code 1 means source code, code 2 means target code
53 | code1, code2 = {}, {}
54 | for k in src_code:
55 | code1[k] = src_code[k].clone()
56 |
57 | for k in tar_code:
58 | code2[k] = tar_code[k].clone()
59 |
60 | code1["pose"][:, :3] = code2["pose"][:, :3]
61 | code1['exp'] = code2['exp']
62 | code1['pose'][:, 3:] = tar_code['pose'][:, 3:]
63 |
64 | opdict, _ = deca.decode(
65 | code1,
66 | render_orig=True,
67 | original_image=tar_image,
68 | tform=src_code["tform"],
69 | align_ffhq=False,
70 | ffhq_center=src_ffhq_center,
71 | imgpath=target
72 | )
73 |
74 | depth = opdict["depth_images"].detach()
75 | normal = opdict["normal_images"].detach()
76 | render = opdict["rendered_images"].detach()
77 | os.makedirs(f'./transfers/{save_file}/depth', exist_ok=True)
78 | os.makedirs(f'./transfers/{save_file}/normal', exist_ok=True)
79 | os.makedirs(f'./transfers/{save_file}/render', exist_ok=True)
80 |
81 | save_image(depth[0], f"./transfers/{save_file}/depth/{imgname}")
82 | save_image(normal[0], f"./transfers/{save_file}/normal/{imgname}")
83 | save_image(render[0], f"./transfers/{save_file}/render/{imgname}")
84 |
85 |
86 | if __name__ == '__main__':
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument(
89 | "--sor_img",
90 | type=str,
91 | default='/home/mengting/projects/diffusionRig/myscripts/papers/example42/boy2_cropped.jpg',
92 | required=False
93 | )
94 | parser.add_argument(
95 | "--driving_path",
96 | type=str,
97 | default='/home/mengting/Desktop/frames_1500_updated/1fsFQ2gF4oE_0/images',
98 | required=False
99 | )
100 | parser.add_argument(
101 | "--save_name",
102 | type=str,
103 | default='example1',
104 | required=False
105 | )
106 |
107 | args = parser.parse_args()
108 | images = sorted(os.listdir(args.driving_path))
109 | for image in images:
110 | cur_image_path = os.path.join(args.driving_path, image)
111 | get_render(args.sor_img, cur_image_path, args.save_name)
112 |
113 | print('done')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | diffusers==0.32.2
3 | einops==0.8.1
4 | face_alignment==1.4.1
5 | facenet_pytorch==2.6.0
6 | imageio==2.37.0
7 | insightface==0.7.3
8 | ipdb==0.13.13
9 | kornia==0.8.0
10 | loguru==0.7.3
11 | mediapipe==0.10.21
12 | moviepy==2.1.2
13 | numpy==2.2.5
14 | omegaconf==2.3.0
15 | opencv_contrib_python==4.11.0.86
16 | opencv_python==4.11.0.86
17 | opencv_python_headless==4.11.0.86
18 | Pillow==11.2.1
19 | pyarrow==19.0.0
20 | pytorch3d==0.7.8
21 | PyYAML==6.0.2
22 | safetensors==0.5.3
23 | scipy==1.15.2
24 | setuptools==75.8.0
25 | scikit-image
26 | timm==0.6.13
27 | torchfile==0.1.0
28 | tqdm==4.67.1
29 | transformers==4.49.0
30 | xformers==0.0.28
31 | yacs==0.1.8
32 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/video_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weimengting/MagicPortrait/010332ac57f87186f79e3efcbc66e0c9e76e3b10/utils/__pycache__/video_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/download.py:
--------------------------------------------------------------------------------
1 | import urllib
2 | import logging
3 | import os
4 | from pathlib import Path
5 | import urllib.request
6 | import tqdm
7 |
8 |
9 | def download(url: str, output: Path) -> None:
10 | if not os.path.exists(output.parent):
11 | os.makedirs(output.parent, exist_ok=True)
12 |
13 | response = urllib.request.urlopen(url)
14 | content_length = response.info().get("Content-Length")
15 | if content_length is None:
16 | raise ValueError("invalid content length")
17 | content_length = int(content_length)
18 |
19 | if os.path.exists(output):
20 | if os.path.getsize(output) == content_length:
21 | print(f"{output} exists. Download skip.")
22 | return
23 |
24 | saved_size = 0
25 |
26 | pbar = tqdm.tqdm(total=content_length)
27 | with open(output, "wb") as f:
28 |         while True:
29 | chunk = response.read(8192)
30 | if not chunk:
31 | break
32 | f.write(chunk)
33 | saved_size += len(chunk)
34 | pbar.update(len(chunk))
35 |
36 | if saved_size != content_length:
37 | os.remove(output)
38 |         raise BlockingIOError("failed to download; incomplete file removed")
39 |
--------------------------------------------------------------------------------
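Editor's note: a usage sketch for `download` above; the URL and destination path are placeholders, not actual MagicPortrait checkpoint locations.

    from pathlib import Path
    from utils.download import download

    # Hypothetical URL and output path; the helper skips the download when the
    # file already exists with the expected Content-Length.
    download(
        "https://example.com/checkpoints/denoising_unet.pth",
        Path("./pretrained_weights/denoising_unet.pth"),
    )
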
/utils/fs.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Any, Generator
4 |
5 |
6 | def traverse_folder(folder: Path) -> Generator[Path, Any, Any]:
7 | if os.path.exists(folder) and os.path.isdir(folder):
8 | children = sorted(os.listdir(folder))
9 | for child in children:
10 | child_abs_path = Path(os.path.join(folder, child))
11 | if os.path.isfile(child_abs_path):
12 | yield child_abs_path
13 | else:
14 | yield from traverse_folder(child_abs_path)
15 | else:
16 | raise ValueError(f"{folder} does not exist or is not a directory")
17 |
--------------------------------------------------------------------------------
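Editor's note: a quick sketch of `traverse_folder` above, which yields every file under a directory tree; the folder name is only an example.

    from pathlib import Path
    from utils.fs import traverse_folder

    for file_path in traverse_folder(Path("./assets")):
        print(file_path)  # Path to each file, walked depth-first in sorted order
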
/utils/postprocess.py:
--------------------------------------------------------------------------------
1 | from moviepy import VideoFileClip, concatenate_videoclips, clips_array, ImageClip, ColorClip, vfx
2 | import os
3 | import cv2
4 | 
5 |
6 |
7 | def print_video_info(video, name):
8 | fps = video.fps
9 |     duration = video.duration # seconds
10 | frame_count = int(fps * duration)
11 |
12 | print(f"视频{name}:")
13 | print(f" 帧率: {fps} fps")
14 | print(f" 时长: {duration:.2f} 秒")
15 | print(f" 总帧数: {frame_count} 帧")
16 | print("")
17 |
18 | def images_to_video(image_dir, output_path, fps=24):
19 |     # Collect the image files and sort them by name
20 | image_files = sorted([
21 | f for f in os.listdir(image_dir) if f.endswith((".jpg", ".jpeg", ".png"))])
22 |
23 |
24 |     # Read the first image to get the frame size
25 | first_image = cv2.imread(os.path.join(image_dir, image_files[0]))
26 | if first_image is None:
27 | raise ValueError("❌ 无法读取第一张图像。")
28 | height, width = first_image.shape[:2]
29 |
30 |     # Set up the video writer
31 |     fourcc = cv2.VideoWriter_fourcc(*'mp4v') # .mp4 output
32 | video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
33 |
34 | for img_name in image_files:
35 | img_path = os.path.join(image_dir, img_name)
36 | frame = cv2.imread(img_path)
37 |
38 |         # Resize frames whose size does not match the first frame
39 | if frame.shape[:2] != (height, width):
40 | frame = cv2.resize(frame, (width, height))
41 |
42 | video_writer.write(frame)
43 |
44 | video_writer.release()
45 | print(f"✅ 视频保存成功: {output_path}")
46 |
47 | def concatenate_videos():
48 |     # Load the video clips
49 | clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut1.mp4")
50 | clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut2.mp4")
51 | # clip3 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/cut3.mp4")
52 |
53 |     # Concatenate the clips
54 | final_clip = concatenate_videoclips([clip1, clip2])
55 |
56 |     # Write to a new file
57 | final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/output_combined.mp4", codec="libx264")
58 |
59 |
60 | def array_videos():
61 |     # Load the two clips
62 | clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4")
63 | clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/animation.mp4")
64 |
65 |     # Side-by-side concatenation (frames aligned)
66 |     final_clip = clips_array([[clip1, clip2]]) # one row, two columns: clip1 left, clip2 right
67 |
68 |     # Export the result
69 | final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/side_by_side.mp4", codec="libx264")
70 |
71 | def set_fps():
72 | clip = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions.mp4")
73 |
74 |     # Set a new frame rate (no frame interpolation)
75 | clip = clip.with_fps(24)
76 |
77 |     # Write out at the new frame rate (same frame count, faster playback)
78 | clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions_24.mp4", fps=24, codec="libx264")
79 |
80 |
81 | def add_static_image():
82 |     # Load the side-by-side video
83 | video = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/output_fast_30fps.mp4")
84 |
85 |     # Load the static image, match the video height, and hold it for the full video duration
86 | image = ImageClip("/home/mengting/Desktop/tmp/champ/gpu_1/boy2_cropped.jpg").resized(height=video.h).with_duration(video.duration)
87 |
88 |     # Side-by-side concatenation (static image + video)
89 | final = clips_array([[image, video]])
90 |
91 |     # Export the result
92 | final.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/final_output.mp4", codec="libx264")
93 | # 179 concatenated, 360 - 179 = 181
94 | # 181 concatenated with 117, 181 - 117 = 64
95 | # 64 concatenated with 150, 150 - 64 = 86
96 | # video 7 and video 8
97 | # video 1 is 275, video2 is 96, video 3 is 360, video5 is 117, video6 is 150, video7 is 98, video8 is 102
98 | def print_info():
99 |     # Load the two videos
100 | video1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person1/video_part1.mp4")
101 | video2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person2/final_output.mp4")
102 |
103 | fps = video1.fps
104 |     duration = video1.duration # seconds
105 | frame_count = int(fps * duration)
106 | print(f" 帧率 (fps): {fps}")
107 | print(f" 总时长 (秒): {duration:.2f}")
108 | print(f" 总帧数: {frame_count}")
109 | print("")
110 |
111 | def split_video():
112 | video = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person8/final_output.mp4")
113 |
114 |     # Get the frame rate and total frame count
115 | fps = video.fps
116 |     total_frames = int(video.duration * fps) # or hard-code total_frames = 275
117 |
118 |     # Set the split point
119 | first_part_frames = 98
120 |     first_part_duration = first_part_frames / fps # in seconds
121 | total_duration = video.duration
122 |
123 |     # Cut the first part (the first 98 frames)
124 | clip1 = video.subclipped(0, first_part_duration)
125 |
126 |     # Cut the second part (the remaining frames)
127 | clip2 = video.subclipped(first_part_duration, total_duration)
128 |
129 |     # Save both parts
130 | clip1.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_1.mp4", codec="libx264", fps=fps)
131 | clip2.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_2.mp4", codec="libx264", fps=fps)
132 |
133 |
134 | def up_and_down():
135 |     # Load the two videos
136 | clip1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person7/final_output.mp4")
137 | clip2 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/person8/video_part_1.mp4")
138 |
139 |
140 |     # Stack vertically (top and bottom)
141 | final = clips_array([[clip1], [clip2]])
142 |
143 |     # Write out the stacked video
144 | final.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_5.mp4", codec="libx264", fps=clip1.fps)
145 |
146 |
147 | def concate_all():
148 |     # List of video files
149 | video_files = ["/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_1.mp4",
150 | "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_2.mp4",
151 | "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_3.mp4",
152 | "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_4.mp4",
153 | "/home/mengting/Desktop/tmp/champ/gpu_1/stacked_vertical_5.mp4"]
154 |
155 |     # Fade-in/fade-out duration (seconds)
156 | fade_duration = 1.0
157 | black_duration = 0.5
158 |
159 |     # Load the clips and add fade-in/fade-out
160 | clips = []
161 | for idx, file in enumerate(video_files):
162 | clip = VideoFileClip(file)
163 | clip = clip.with_effects([vfx.FadeIn(duration=fade_duration)])
164 | clip = clip.with_effects([vfx.FadeOut(duration=black_duration)])
165 | clips.append(clip)
166 |
167 |         # Insert a black clip between videos (not after the last one)
168 | if idx < len(video_files) - 1:
169 | black = ColorClip(size=clip.size, color=(0, 0, 0), duration=black_duration)
170 | clips.append(black)
171 |
172 |     # Concatenate all clips
173 | final_clip = concatenate_videoclips(clips, method="compose")
174 |
175 |     # Export the video
176 | final_clip.write_videofile("/home/mengting/Desktop/tmp/champ/gpu_1/combined_with_transitions.mp4", codec="libx264", fps=clips[0].fps)
177 |
178 | if __name__ == '__main__':
179 | #concatenate_videos()
180 | #images_to_video('/home/mengting/Desktop/frames_1500_updated/1fsFQ2gF4oE_0/images', '/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4')
181 | #array_videos()
182 | # video1 = VideoFileClip("/home/mengting/Desktop/tmp/champ/gpu_1/driving.mp4")
183 | # print_video_info(video1, "driving")
184 | set_fps()
185 | #add_static_image()
186 | #print_info()
187 | #split_video()
188 | #up_and_down()
189 | #concate_all()
--------------------------------------------------------------------------------
/utils/tb_tracker.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from accelerate.tracking import GeneralTracker, on_main_process
3 | import os
4 | from typing import Union
5 |
6 |
7 | class TbTracker(GeneralTracker):
8 |
9 | name = "tensorboard"
10 | requires_logging_directory = True
11 |
12 | @on_main_process
13 | def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike],
14 | **kwargs):
15 | super().__init__()
16 | self.run_name = run_name
17 | self.logging_dir = os.path.join(logging_dir, run_name)
18 | self.writer = SummaryWriter(self.logging_dir, **kwargs)
19 |
20 | @property
21 | def tracker(self):
22 | return self.writer
23 |
24 | @on_main_process
25 | def add_scalar(self, tag, scalar_value, **kwargs):
26 | self.writer.add_scalar(tag=tag, scalar_value=scalar_value, **kwargs)
27 |
28 | @on_main_process
29 | def add_text(self, tag, text_string, **kwargs):
30 | self.writer.add_text(tag=tag, text_string=text_string, **kwargs)
31 |
32 | @on_main_process
33 | def add_figure(self, tag, figure, **kwargs):
34 | self.writer.add_figure(tag=tag, figure=figure, **kwargs)
35 |
--------------------------------------------------------------------------------
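Editor's note: a minimal sketch of using `TbTracker` above on its own (assuming `tensorboard` is installed); the run name and logging directory are placeholders. During training it would normally be handed to an `accelerate` `Accelerator` instead.

    from utils.tb_tracker import TbTracker

    # Hypothetical run name and log directory.
    tracker = TbTracker(run_name="debug_run", logging_dir="./logs")
    tracker.add_scalar("train/loss", 0.123, global_step=1)
    tracker.add_text("notes", "smoke test of the tracker")
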
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import sys
5 | from pathlib import Path
6 |
7 | # import av
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from einops import rearrange
12 | from PIL import Image
13 |
14 |
15 | def seed_everything(seed):
16 | import random
17 | import numpy as np
18 |
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | np.random.seed(seed % (2**32))
22 | random.seed(seed)
23 |
24 | def delete_additional_ckpt(base_path, num_keep):
25 | dirs = []
26 | for d in os.listdir(base_path):
27 | if d.startswith("checkpoint-"):
28 | dirs.append(d)
29 | num_tot = len(dirs)
30 | if num_tot <= num_keep:
31 | return
32 |     # ensure checkpoints are sorted and delete the earliest ones
33 | del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
34 | for d in del_dirs:
35 | path_to_dir = osp.join(base_path, d)
36 | if osp.exists(path_to_dir):
37 | shutil.rmtree(path_to_dir)
38 |
39 | def compute_snr(noise_scheduler, timesteps):
40 | """
41 | Computes SNR as per
42 | https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
43 | """
44 | alphas_cumprod = noise_scheduler.alphas_cumprod
45 | sqrt_alphas_cumprod = alphas_cumprod**0.5
46 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
47 |
48 | # Expand the tensors.
49 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
50 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
51 | timesteps
52 | ].float()
53 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
54 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
55 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
56 |
57 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
58 | device=timesteps.device
59 | )[timesteps].float()
60 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
61 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
62 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
63 |
64 | # Compute SNR.
65 | snr = (alpha / sigma) ** 2
66 | return snr
67 |
--------------------------------------------------------------------------------
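Editor's note: a small sketch of calling `compute_snr` above with a stock `diffusers` scheduler; the default scheduler settings are chosen only for illustration and are not the project's training configuration.

    import torch
    from diffusers import DDPMScheduler
    from utils.util import compute_snr

    noise_scheduler = DDPMScheduler()
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))

    snr = compute_snr(noise_scheduler, timesteps)  # one SNR value per sampled timestep
    print(snr)
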
/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | import torch.nn.functional as F
6 | from PIL import Image
7 | from pathlib import Path
8 | import imageio
9 | from einops import rearrange
10 | import torchvision.transforms as transforms
11 |
12 |
13 | def save_videos_from_pil(pil_images, path, fps=15, crf=23):
14 |
15 | save_fmt = Path(path).suffix
16 | os.makedirs(os.path.dirname(path), exist_ok=True)
17 |
18 | if save_fmt == ".mp4":
19 | with imageio.get_writer(path, fps=fps) as writer:
20 | for img in pil_images:
21 | img_array = np.array(img) # Convert PIL Image to numpy array
22 | writer.append_data(img_array)
23 |
24 | elif save_fmt == ".gif":
25 | pil_images[0].save(
26 | fp=path,
27 | format="GIF",
28 | append_images=pil_images[1:],
29 | save_all=True,
30 | duration=(1 / fps * 1000),
31 | loop=0,
32 | )
33 | else:
34 | raise ValueError("Unsupported file type. Use .mp4 or .gif.")
35 |
36 |
37 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=24):
38 | videos = rearrange(videos, "b c t h w -> t b c h w")
39 | height, width = videos.shape[-2:]
40 | outputs = []
41 |
42 | for x in videos:
43 | x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
44 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
45 | if rescale:
46 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
47 | x = (x * 255).numpy().astype(np.uint8)
48 | x = Image.fromarray(x)
49 |
50 | outputs.append(x)
51 |
52 | os.makedirs(os.path.dirname(path), exist_ok=True)
53 |
54 | save_videos_from_pil(outputs, path, fps)
55 |
56 |
57 | def resize_tensor_frames(video_tensor, new_size):
58 | B, C, video_length, H, W = video_tensor.shape
59 | # Reshape video tensor to combine batch and frame dimensions: (B*F, C, H, W)
60 | video_tensor_reshaped = video_tensor.reshape(-1, C, H, W)
61 | # Resize using interpolate
62 | resized_frames = F.interpolate(
63 | video_tensor_reshaped, size=new_size, mode="bilinear", align_corners=False
64 | )
65 | resized_video = resized_frames.reshape(B, C, video_length, new_size[0], new_size[1])
66 |
67 | return resized_video
68 |
69 |
70 | def pil_list_to_tensor(image_list, size=None):
71 | to_tensor = transforms.ToTensor()
72 | if size is not None:
73 | tensor_list = [to_tensor(img.resize(size[::-1])) for img in image_list]
74 | else:
75 | tensor_list = [to_tensor(img) for img in image_list]
76 | stacked_tensor = torch.stack(tensor_list, dim=0)
77 | tensor = stacked_tensor.permute(1, 0, 2, 3)
78 | return tensor
79 |
80 |
81 | def concatenate_into_video():
82 | gt_list = []
83 | gt_root = '/home/mengting/projects/champ_abls/no_exp_coeff/results/output_images'
84 | imgs = sorted(os.listdir(gt_root))
85 | for img in imgs:
86 | cur_img_path = os.path.join(gt_root, img)
87 | tmp_img = Image.open(cur_img_path)
88 | tmp_img = tmp_img.resize((512, 512))
89 | tmp_img = transforms.ToTensor()(tmp_img)
90 | gt_list.append(tmp_img)
91 | gt_list = torch.stack(gt_list, dim=1).unsqueeze(0)
92 | print(gt_list.shape)
93 |
94 | ref_image_path = '/home/mengting/Desktop/frames_1500_updated/4Z7qKXu9Sck_2/images/frame_0000.jpg'
95 | ref_image_pil = Image.open(ref_image_path)
96 | ref_image_w, ref_image_h = ref_image_pil.size
97 | video_length = len(imgs)
98 | ref_video_tensor = transforms.ToTensor()(ref_image_pil)[None, :, None, ...].repeat(
99 | 1, 1, video_length, 1, 1
100 | )
101 |
102 | drive_list = []
103 | guidance_path = '/home/mengting/Desktop/frames_new_1500/2yj1P52T1X8_4/images'
104 | imgs = sorted(os.listdir(guidance_path))
105 | for i, img in enumerate(imgs):
106 | cur_img_path = os.path.join(guidance_path, img)
107 | tmp_img = Image.open(cur_img_path)
108 | tmp_img = transforms.ToTensor()(tmp_img)
109 | drive_list.append(tmp_img)
110 | if len(drive_list) == video_length:
111 | break
112 | drive_list = torch.stack(drive_list, dim=1).unsqueeze(0)
113 | print(drive_list.shape, ref_video_tensor.shape)
114 |
115 | save_dir = '/home/mengting/projects/champ_abls/no_exp_coeff/results/comparison'
116 | grid_video = torch.cat([drive_list, ref_video_tensor, gt_list], dim=0)
117 | save_videos_grid(grid_video, os.path.join(save_dir, "grid_wdrive_aniportrait.mp4"))
118 |
119 | if __name__ == '__main__':
120 |     concatenate_into_video()
--------------------------------------------------------------------------------
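Editor's note: finally, a short sketch of `save_videos_grid` above with a random tensor standing in for decoded frames; the shapes and output path are hypothetical, and writing `.mp4` assumes imageio's ffmpeg backend is available.

    import torch
    from utils.video_utils import save_videos_grid

    # Hypothetical batch: 2 videos, 3 channels, 8 frames, 64x64, values in [0, 1].
    videos = torch.rand(2, 3, 8, 64, 64)
    save_videos_grid(videos, "./results/demo_grid.mp4", rescale=False, n_rows=2, fps=24)
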