├── .gitignore
├── LICENSE
├── README.md
├── base
│   ├── __init__.py
│   ├── baseTrainer.py
│   ├── base_model.py
│   ├── config.py
│   └── utilities.py
├── comparisons.png
├── config
│   └── HDTF
│       └── config.yaml
├── data
│   └── dataloader_HDTF.py
├── demo.py
├── external
│   └── spectre
│       ├── .gitignore
│       ├── .gitmodules
│       ├── LICENSE
│       ├── README.md
│       ├── __init__.py
│       ├── config.py
│       ├── configs
│       │   └── lipread_config.ini
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── build_datasets.py
│       │   ├── data_utils.py
│       │   ├── datasets.py
│       │   └── extra_datasets.py
│       ├── demo.py
│       ├── get_training_data.sh
│       ├── main.py
│       ├── quick_install.sh
│       ├── render.py
│       ├── requirements.txt
│       ├── src
│       │   ├── __init__.py
│       │   ├── models
│       │   │   ├── FLAME.py
│       │   │   ├── encoders.py
│       │   │   ├── expression_loss.py
│       │   │   ├── lbs.py
│       │   │   └── resnet.py
│       │   ├── spectre.py
│       │   ├── trainer_spectre.py
│       │   └── utils
│       │       ├── lossfunc.py
│       │       ├── renderer.py
│       │       ├── rotation_converter.py
│       │       ├── tensor_cropper.py
│       │       ├── trainer.py
│       │       └── util.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── extract_frames_LRS3.py
│       │   ├── extract_frames_and_audio.py
│       │   ├── extract_wavs_LRS3.py
│       │   ├── lipread_utils.py
│       │   └── run_av_hubert.py
│       └── visual_mesh.py
├── framework.png
├── losses
│   └── loss_collections.py
├── models
│   ├── lib
│   │   ├── base_models.py
│   │   ├── grl_module.py
│   │   ├── modules.py
│   │   └── wav2vec.py
│   └── network.py
├── requirements.txt
├── tools
│   └── render_spectre.py
├── train.py
└── utils
    └── render_pyrender.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python files #
2 | ################
3 | *.pyc
4 | __pycache__
5 | # OS generated files #
6 | ######################
7 | .DS_Store
8 | .DS_Store?
9 | ._*
10 | .Spotlight-V100
11 | .Trashes
12 | ehthumbs.db
13 | Thumbs.db
14 | .vscode
15 |
16 | # Packages #
17 | ############
18 | # it's better to unpack these files and commit the raw source
19 | # git has its own built in compression methods
20 | *.7z
21 | *.dmg
22 | *.gz
23 | *.iso
24 | *.jar
25 | *.rar
26 | *.tar
27 | *.zip
28 | *.pth
29 |
30 | demos
31 | pretrained
32 | external/spectre/pretrained
33 | external/spectre/data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 DoubleXING
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## **Mimic**
2 |
3 | Official PyTorch implementation for the paper:
4 |
5 | > **Mimic: Speaking Style Disentanglement for Speech-Driven 3D Facial Animation**, ***AAAI 2024***.
6 | >
7 | > Hui Fu, Zeqing Wang, Ke Gong, Keze Wang, Tianshui Chen, Haojie Li, Haifeng Zeng, Wenxiong Kang
8 | >
9 | >
10 |
21 | Our method performs visual-speech aware 3D reconstruction so that speech perception from the original footage is preserved in the reconstructed talking head. On the left we include the word/phrase being said for each example.
22 |
23 | This is the official PyTorch implementation of the paper:
24 |
25 | ```
26 | Visual Speech-Aware Perceptual 3D Facial Expression Reconstruction from Videos
27 | Panagiotis P. Filntisis, George Retsinas, Foivos Paraperas-Papantoniou, Athanasios Katsamanis, Anastasios Roussos, and Petros Maragos
28 | arXiv 2022
29 | ```
30 |
31 |
32 |
33 | ## Installation
34 | Clone the repo and its submodules:
35 | ```bash
36 | git clone --recurse-submodules -j4 https://github.com/filby89/spectre
37 | cd spectre
38 | ```
39 |
40 | You need a working installation of PyTorch with Python 3.6 or higher, as well as PyTorch3D. You can use the following commands to create a working installation:
41 | ```bash
42 | conda create -n "spectre" python=3.8
43 | conda install -c pytorch pytorch=1.11.0 torchvision torchaudio # you might need to select cudatoolkit version here by adding e.g. cudatoolkit=11.3
44 | conda install -c conda-forge -c fvcore fvcore iopath
45 | conda install pytorch3d -c pytorch3d
46 | pip install -r requirements.txt # install the rest of the requirements
47 | ```
48 |
49 | Installing a working setup of PyTorch3D together with PyTorch can be a bit tricky. For development we used PyTorch3D 0.6.1 with PyTorch 1.10.0.
50 |
51 | PyTorch3D 0.6.2 with PyTorch 1.11.0 is also compatible.
52 |
53 | Install the face_alignment and face_detection packages:
54 | ```bash
55 | cd external/face_alignment
56 | pip install -e .
57 | cd ../face_detection
58 | git lfs pull
59 | pip install -e .
60 | cd ../..
61 | ```
62 | You may need to install git-lfs to run the above commands. [More details](https://stackoverflow.com/questions/48734119/git-lfs-is-not-a-git-command-unclear)
63 | ```bash
64 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
65 | sudo apt-get install git-lfs
66 | ```
67 | Download the FLAME model and the pretrained SPECTRE model:
68 | ```bash
69 | pip install gdown
70 | bash quick_install.sh
71 | ```
72 |
73 | ## Demo
74 | Samples are included in the ``samples`` folder. You can run the demo by running:
75 |
76 | ```bash
77 | python demo.py --input samples/LRS3/0Fi83BHQsMA_00002.mp4 --audio
78 | ```
79 |
80 | The ``--audio`` flag extracts the audio from the input video and adds it to the output shape video for visualization purposes (ffmpeg is required for video creation).
81 |
82 | ## Training and Testing
83 | In order to train the model you need to download the `trainval` and `test` sets of the [LRS3 dataset](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs3.html). After downloading
84 | the dataset, run the following command to extract frames and audio from the videos (audio is not needed for training but it is nice for visualizing the result):
85 |
86 | ```bash
87 | python utils/extract_frames_and_audio.py --dataset_path ./data/LRS3
88 | ```
89 |
90 | After downloading and preprocessing the dataset, download the rest of the needed assets:
91 |
92 | ```bash
93 | bash get_training_data.sh
94 | ```
95 |
96 | This command downloads the original [DECA](https://github.com/YadiraF/DECA/) pretrained model,
97 | the ResNet50 emotion recognition model provided by [EMOCA](https://github.com/radekd91/emoca),
98 | the pretrained lipreading model and detected landmarks for the videos of the LRS3 dataset provided by [Visual_Speech_Recognition_for_Multiple_Languages](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages).
99 |
100 | Finally, you need to create a texture model using the repository [BFM_to_FLAME](https://github.com/TimoBolkart/BFM_to_FLAME#create-texture-model). Due
101 | to licensing reasons we are not allowed to share it with you.
102 |
103 | Now, you can run the following command to train the model:
104 |
105 | ```bash
106 | python main.py --output_dir logs --landmark 50 --relative_landmark 25 --lipread 2 --expression 0.5 --epochs 6 --LRS3_path data/LRS3 --LRS3_landmarks_path data/LRS3_landmarks
107 | ```
108 |
109 | and then test it on the LRS3 test set:
110 |
111 | ```bash
112 | python main.py --test --output_dir logs --model_path logs/model.tar --LRS3_path data/LRS3 --LRS3_landmarks_path data/LRS3_landmarks
113 | ```
114 |
115 | and run lipreading with AV-HuBERT:
116 |
117 | ```bash
118 | # and run lipreading with our script
119 | python utils/run_av_hubert.py --videos "logs/test_videos_000000/*_mouth.avi" --LRS3_path data/LRS3
120 | ```
121 |
122 |
123 | ## Acknowledgements
124 | This repo has been heavily based on the original implementation of [DECA](https://github.com/YadiraF/DECA/). We also acknowledge the following
125 | repositories, from which we have benefited greatly:
126 |
127 | - [EMOCA](https://github.com/radekd91/emoca)
128 | - [face_alignment](https://github.com/hhj1897/face_alignment)
129 | - [face_detection](https://github.com/hhj1897/face_detection)
130 | - [Visual_Speech_Recognition_for_Multiple_Languages](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages)
131 |
132 | ## Citation
133 | If your research benefits from this repository, consider citing the following:
134 |
135 | ```
136 | @misc{filntisis2022visual,
137 | title = {Visual Speech-Aware Perceptual 3D Facial Expression Reconstruction from Videos},
138 | author = {Filntisis, Panagiotis P. and Retsinas, George and Paraperas-Papantoniou, Foivos and Katsamanis, Athanasios and Roussos, Anastasios and Maragos, Petros},
139 | publisher = {arXiv},
140 | year = {2022},
141 | }
142 | ```
143 |
144 |
145 |
--------------------------------------------------------------------------------
/external/spectre/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/__init__.py
--------------------------------------------------------------------------------
/external/spectre/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Default config for SPECTRE - adapted from DECA
3 | '''
4 | from yacs.config import CfgNode as CN
5 | import argparse
6 | import yaml
7 | import os
8 |
9 | cfg = CN()
10 |
11 | cfg.project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src', '..'))
12 | cfg.device = 'cuda'
13 | cfg.device_ids = '0'
14 |
15 | cfg.pretrained_modelpath = os.path.join(cfg.project_dir, 'data', 'deca_model.tar')
16 | cfg.output_dir = ''
17 | cfg.rasterizer_type = 'pytorch3d'
18 | # ---------------------------------------------------------------------------- #
19 | # Options for FLAME and from original DECA
20 | # ---------------------------------------------------------------------------- #
21 | cfg.model = CN()
22 | cfg.model.topology_path = os.path.join(cfg.project_dir, 'data' , 'head_template.obj')
23 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip
24 | cfg.model.dense_template_path = os.path.join(cfg.project_dir, 'data', 'texture_data_256.npy')
25 | cfg.model.fixed_displacement_path = os.path.join(cfg.project_dir, 'data', 'fixed_displacement_256.npy')
26 | cfg.model.flame_model_path = os.path.join(cfg.project_dir, 'data', 'FLAME2020', 'generic_model.pkl')
27 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.project_dir, 'data', 'landmark_embedding.npy')
28 | cfg.model.face_mask_path = os.path.join(cfg.project_dir, 'data', 'uv_face_mask.png')
29 | cfg.model.face_eye_mask_path = os.path.join(cfg.project_dir, 'data', 'uv_face_eye_mask.png')
30 | cfg.model.mean_tex_path = os.path.join(cfg.project_dir, 'data', 'mean_texture.jpg')
31 | cfg.model.tex_path = os.path.join(cfg.project_dir, 'data', 'FLAME_albedo_from_BFM.npz')
32 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM
33 | cfg.model.uv_size = 256
34 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light']
35 | cfg.model.n_shape = 100
36 | cfg.model.n_tex = 50
37 | cfg.model.n_exp = 50
38 | cfg.model.n_cam = 3
39 | cfg.model.n_pose = 6
40 | cfg.model.n_light = 27
41 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler. Note that: aa is not stable in the beginning
42 |
43 |
44 |
45 | cfg.model.model_type = "SPECTRE"
46 |
47 | cfg.model.temporal = True
48 |
49 |
50 | # ---------------------------------------------------------------------------- #
51 | # Options for Dataset
52 | # ---------------------------------------------------------------------------- #
53 | cfg.dataset = CN()
54 | cfg.dataset.LRS3_path = "/gpu-data3/filby/LRS3"
55 | cfg.dataset.LRS3_landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/LRS3/LRS3_landmarks"
56 |
62 |
63 | cfg.dataset.batch_size = 1
64 | cfg.dataset.K = 20
65 | cfg.dataset.num_workers = 8
66 | cfg.dataset.image_size = 224 # 224/500
67 | cfg.dataset.scale_min = 1.4
68 | cfg.dataset.scale_max = 1.8
69 | cfg.dataset.trans_scale = 0.
70 | cfg.dataset.fps = 25
71 | cfg.dataset.test_datasets = ['LRS3']
72 |
73 | # ---------------------------------------------------------------------------- #
74 | # Options for training
75 | # ---------------------------------------------------------------------------- #
76 | cfg.train = CN()
77 | cfg.train.max_epochs = 6
78 | cfg.train.log_dir = 'logs'
79 | cfg.train.log_steps = 10
80 | cfg.train.vis_dir = 'train_images'
81 | cfg.train.vis_steps = 500
82 | cfg.train.write_summary = True
83 | cfg.train.checkpoint_steps = 10000
84 | cfg.train.val_vis_dir = 'val_images'
85 |
86 | cfg.train.evaluation_steps = 10000
87 |
88 | # ---------------------------------------------------------------------------- #
89 | # Options for Losses
90 | # ---------------------------------------------------------------------------- #
91 | cfg.loss = CN()
92 | cfg.loss.train = CN()
93 |
94 | cfg.model.use_tex = True
95 | cfg.model.regularization_type = 'nonlinear'
96 | cfg.model.backbone = 'mobilenetv2' # perceptual encoder backbone
97 |
98 | cfg.loss.train.landmark = 50
99 | cfg.loss.train.lip_landmarks = 0
100 | cfg.loss.train.relative_landmark = 50
101 | cfg.loss.train.photometric_texture = 0
102 | cfg.loss.train.lipread = 2
103 | cfg.loss.train.jaw_reg = 200
104 | cfg.train.lr = 5e-5
105 | cfg.loss.train.expression = 0.5
106 |
107 | cfg.test_mode = False
108 |
109 | def get_cfg_defaults():
110 | """Get a yacs CfgNode object with default values for my_project."""
111 | # Return a clone so that the defaults will not be altered
112 | # This is for the "local variable" use pattern
113 | return cfg.clone()
114 |
115 | def update_cfg(cfg, cfg_file):
116 | cfg.merge_from_file(cfg_file)
117 | return cfg.clone()
118 |
119 | def parse_args():
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument('--output_dir', type=str, help='output path')
122 | parser.add_argument('--LRS3_path', default=None, type=str, help='path to LRS3 dataset')
123 | parser.add_argument('--LRS3_landmarks_path', default=None, type=str, help='path to LRS3 landmarks')
124 | parser.add_argument('--model_path', default=None, help='path to pretrained model')
125 | parser.add_argument('--batch-size', type=int, default=1, help='the batch size')
126 | parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train for')
127 | parser.add_argument('--K', type=int, default=20, help='length of sampled frame sequence')
128 | parser.add_argument('--lipread', type=float, default=None, help='lipread loss weight')
129 | parser.add_argument('--expression', type=float, default=None, help='expression loss weight')
130 | parser.add_argument('--lr', type=float, default=None, help='learning rate')
131 | parser.add_argument('--landmark', type=float, default=None, help='landmark loss weight')
132 | parser.add_argument('--relative_landmark', type=float, default=None, help='relative landmark loss weight')
133 | parser.add_argument('--backbone', type=str, default='mobilenetv2', choices=['mobilenetv2', 'resnet50'])
134 |
135 | parser.add_argument('--test', action='store_true', help='test mode')
136 | parser.add_argument('--test_datasets', type=str, nargs='+', default=['LRS3'], help='test datasets')
137 |
138 | args = parser.parse_args()
139 |
140 | cfg = get_cfg_defaults()
141 |
142 | cfg.output_dir = args.output_dir
143 |
144 | if args.model_path is not None:
145 | cfg.pretrained_modelpath = args.model_path
146 |
147 | if args.batch_size is not None:
148 | cfg.dataset.batch_size = args.batch_size
149 |
150 | cfg.dataset.K = args.K
151 |
152 | if args.landmark is not None:
153 | cfg.loss.train.landmark = args.landmark
154 |
155 | if args.relative_landmark is not None:
156 | cfg.loss.train.relative_landmark = args.relative_landmark
157 |
158 | if args.lipread is not None:
159 | cfg.loss.train.lipread = args.lipread
160 |
161 | if args.expression is not None:
162 | cfg.loss.train.expression = args.expression
163 |
164 | if args.lr is not None:
165 | cfg.train.lr = args.lr
166 |
167 | if args.epochs is not None:
168 | cfg.train.max_epochs = args.epochs
169 |
170 | if args.LRS3_path is not None:
171 | cfg.dataset.LRS3_path = args.LRS3_path
172 |
173 | if args.LRS3_landmarks_path is not None:
174 | cfg.dataset.LRS3_landmarks_path = args.LRS3_landmarks_path
175 |
176 | cfg.model.backbone = args.backbone
177 |
178 | cfg.test_mode = args.test
179 |
180 | cfg.test_datasets = args.test_datasets
181 |
182 | return cfg
183 |
--------------------------------------------------------------------------------
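
As a quick orientation for the yacs-based config above, here is a minimal usage sketch (hypothetical override file name; assumes it is run from `external/spectre/` so that `config.py` is importable):

```python
from config import get_cfg_defaults, update_cfg

cfg = get_cfg_defaults()                      # clone of the defaults defined above
cfg.dataset.LRS3_path = "data/LRS3"           # override a value in code ...
# ... or merge overrides from a YAML file (hypothetical path):
# cfg = update_cfg(cfg, "configs/my_overrides.yaml")

cfg.freeze()                                  # optional: lock against accidental edits
print(cfg.loss.train.lipread)                 # -> 2
```
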
/external/spectre/configs/lipread_config.ini:
--------------------------------------------------------------------------------
1 | [input]
2 | modality=video
3 | v_fps=25
4 |
5 | [model]
6 | v_fps=25
7 | model_path=data/LRS3_V_WER32.3/model.pth
8 | model_conf=data/LRS3_V_WER32.3/model.json
9 | rnnlm=
10 | rnnlm_conf=
11 |
12 | [decode]
13 | beam_size=1
14 | penalty=0.5
15 | maxlenratio=0.0
16 | minlenratio=0.0
17 | ctc_weight=0.1
18 | lm_weight=0.6
19 |
--------------------------------------------------------------------------------
/external/spectre/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/datasets/__init__.py
--------------------------------------------------------------------------------
/external/spectre/datasets/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def linear_interpolate(landmarks, start_idx, stop_idx):
4 | """linear_interpolate.
5 |
6 | :param landmarks: ndarray, input landmarks to be interpolated.
7 | :param start_idx: int, the start index for linear interpolation.
8 |     :param stop_idx: int, the stop index for linear interpolation.
9 | """
10 | start_landmarks = landmarks[start_idx]
11 | stop_landmarks = landmarks[stop_idx]
12 | delta = stop_landmarks - start_landmarks
13 | for idx in range(1, stop_idx-start_idx):
14 | landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta
15 | return landmarks
16 |
17 | def landmarks_interpolate(landmarks):
18 | """landmarks_interpolate.
19 |
20 |     :param landmarks: List, the raw landmarks (modified in place)
21 |
22 | """
23 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
24 | if not valid_frames_idx:
25 | return None
26 | for idx in range(1, len(valid_frames_idx)):
27 | if valid_frames_idx[idx] - valid_frames_idx[idx - 1] == 1:
28 | continue
29 | else:
30 | landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx])
31 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
32 |     # -- Corner case: frames at the beginning or at the end failed to be detected; pad them with the nearest valid landmarks.
33 | if valid_frames_idx:
34 | landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0]
35 | landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1])
36 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None]
37 | assert len(valid_frames_idx) == len(landmarks), "not every frame has landmark"
38 | return landmarks
39 |
40 |
41 | def create_LRS3_lists(lrs3_path):
42 | from sklearn.model_selection import train_test_split
43 | import pickle
44 | trainval_folder_list = list(os.listdir(f"{lrs3_path}/trainval"))
45 | train_folder_list, val_folder_list = train_test_split(trainval_folder_list, test_size=0.2, random_state=42)
46 |
47 |
48 | train_list = []
49 | for folder in train_folder_list:
50 | for file in os.listdir(os.path.join(f"{lrs3_path}/trainval", folder)):
51 | if file.endswith(".txt"):
52 | file_without_extension = file.split(".")[0]
53 | train_list.append(f"trainval/{folder}/{file_without_extension}")
54 |
55 |
56 | val_list = []
57 | for folder in val_folder_list:
58 | for file in os.listdir(os.path.join(f"{lrs3_path}/trainval", folder)):
59 | if file.endswith(".txt"):
60 | file_without_extension = file.split(".")[0]
61 | val_list.append(f"trainval/{folder}/{file_without_extension}")
62 |
63 | #
64 | test_folder_list = list(os.listdir(f"{lrs3_path}/test"))
65 | test_list = []
66 | for folder in test_folder_list:
67 | for file in os.listdir(os.path.join(f"{lrs3_path}/test", folder)):
68 | if file.endswith(".txt"):
69 | file_without_extension = file.split(".")[0]
70 | test_list.append(f"test/{folder}/{file_without_extension}")
71 |
72 |
73 | pickle.dump([train_list,val_list,test_list], open(f"data/LRS3_lists.pkl", "wb"))
74 |
--------------------------------------------------------------------------------
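
A small illustration (toy data; assumes it is run from `external/spectre/` so that `datasets.data_utils` is importable) of how `landmarks_interpolate` fills in frames whose face detection failed:

```python
import numpy as np
from datasets.data_utils import landmarks_interpolate

# Toy sequence of five frames of 68x2 landmarks; detection failed on frames 1, 2 and 4.
lmk = np.zeros((68, 2))
sequence = [lmk, None, None, lmk + 3.0, None]

filled = landmarks_interpolate(sequence)
# Frames 1 and 2 are linearly interpolated between frames 0 and 3;
# the trailing frame 4 is padded with the last valid landmarks (frame 3).
assert all(frame is not None for frame in filled)
```
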
/external/spectre/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | import numpy as np
5 | import cv2
6 | from skimage.transform import estimate_transform, warp
7 | import random
8 | import pickle
9 | from .data_utils import landmarks_interpolate
10 |
11 | class SpectreDataset(Dataset):
12 | def __init__(self, data_list, landmarks_path, cfg, test=False):
13 | self.data_list = data_list
14 | self.image_size = 224
15 | self.K = cfg.K
16 | self.test = test
17 | self.cfg=cfg
18 | self.landmarks_path = landmarks_path
19 |
20 | if not self.test:
21 | self.scale = [1.4, 1.8]
22 | else:
23 | self.scale = 1.6
24 |
25 | def crop_face(self, frame, landmarks, scale=1.0):
26 | left = np.min(landmarks[:, 0])
27 | right = np.max(landmarks[:, 0])
28 | top = np.min(landmarks[:, 1])
29 | bottom = np.max(landmarks[:, 1])
30 |
31 | h, w, _ = frame.shape
32 | old_size = (right - left + bottom - top) / 2
33 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) # + old_size*0.1])
34 |
35 | size = int(old_size * scale)
36 |
37 | # crop image
38 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
39 | [center[0] + size / 2, center[1] - size / 2]])
40 | DST_PTS = np.array([[0, 0], [0, self.image_size - 1], [self.image_size - 1, 0]])
41 | tform = estimate_transform('similarity', src_pts, DST_PTS)
42 |
43 | return tform
44 |
45 | def __len__(self):
46 | return len(self.data_list)
47 |
48 | def __getitem__(self, index):
49 | images_list = []; kpt_list = [];
50 |
51 | sample = self.data_list[index]
52 |
53 | landmarks_filename = os.path.join(self.landmarks_path, sample[0]+".pkl")
54 | folder_path = os.path.join(self.cfg.LRS3_path, sample[0])
55 |
56 | with open(landmarks_filename, "rb") as pkl_file:
57 | landmarks = pickle.load(pkl_file)
58 | preprocessed_landmarks = landmarks_interpolate(landmarks)
59 | if preprocessed_landmarks is None:
60 | return None
61 |
62 | if self.test:
63 | frame_indices = list(range(len(landmarks)))
64 | else:
65 | if len(landmarks) < self.K:
66 | start_idx = 0
67 | end_idx = len(landmarks)
68 | else:
69 | start_idx = random.randint(0, len(landmarks) - self.K)
70 | end_idx = start_idx + self.K
71 |
72 | frame_indices = list(range(start_idx,end_idx))
73 |
74 | if isinstance(self.scale, list):
75 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
76 | else:
77 | scale = self.scale
78 |
79 | for frame_idx in frame_indices:
80 | if "LRS3" in self.landmarks_path:
81 | frame = cv2.imread(os.path.join(folder_path,"%06d.jpg"%(frame_idx)))
82 | folder_path = os.path.join(self.cfg.LRS3_path, sample[0])
83 | wav = folder_path + ".wav"
84 | else: # during test mode for other datasets
85 | if 'MEAD' in self.landmarks_path:
86 | folder_path = os.path.join("/gpu-data3/filby/MEAD/rendered/train/MEAD/images", sample[0])
87 | frame = cv2.imread(os.path.join(folder_path,"%06d.png"%(frame_idx)))
88 | wav = folder_path.replace("images","wavs") + ".wav"
89 | else:
90 | folder_path = os.path.join("/gpu-data3/filby/EAVTTS/TCDTIMIT_preprocessed/images", sample[0])
91 | frame = cv2.imread(os.path.join(folder_path,"%06d.png"%(frame_idx)))
92 | wav = folder_path.replace("images","wavs") + ".wav"
93 |
94 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
95 | kpt = preprocessed_landmarks[frame_idx]
96 | tform = self.crop_face(frame,kpt,scale)
97 | cropped_image = warp(frame, tform.inverse, output_shape=(self.image_size, self.image_size))
98 |
99 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T
100 |
101 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1
102 |
103 | images_list.append(cropped_image.transpose(2,0,1))
104 | kpt_list.append(cropped_kpt)
105 |
106 |         images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32)  # K,3,224,224
107 |         kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32)  # K,68,3
108 |
109 | # text = open(folder_path+".txt").readlines()[0].replace("Text:","").strip()
110 | text = sample[1] # open(folder_path+".txt").readlines()[0].replace("Text:","").strip()
111 |
112 | data_dict = {
113 | 'image': images_array,
114 | 'landmark': kpt_array,
115 | 'vid_name': sample[0],
116 | 'wav_path': wav, # this is only used for evaluation - you can remove this key from the dictionary if you don't need it
117 | 'text': text, # this is only used for evaluation - you can remove this key from the dictionary if you don't need it
118 | }
119 |
120 | return data_dict
121 |
122 |
123 | def get_datasets_LRS3(config=None):
124 | if not os.path.exists('data/LRS3_lists.pkl'):
125 | print('Creating train, validation, and test lists for LRS3... (This only happens once)')
126 |
127 | from .data_utils import create_LRS3_lists
128 | create_LRS3_lists(config.LRS3_path)
129 |
130 |
131 | lists = pickle.load(open("data/LRS3_lists.pkl", "rb"))
132 | train_list = lists[0]
133 | val_list = lists[1]
134 | test_list = lists[2]
135 | landmarks_path = config.LRS3_landmarks_path
136 | return SpectreDataset(train_list, landmarks_path, cfg=config), SpectreDataset(val_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path,
137 | cfg=config,
138 | test=True)
139 |
--------------------------------------------------------------------------------
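
A minimal sketch of wiring `get_datasets_LRS3` into a PyTorch `DataLoader` (assumptions: LRS3 frames and landmarks are already extracted at the configured paths, the script is run from `external/spectre/`, and `skip_none_collate` is a hypothetical helper that drops samples whose landmarks could not be interpolated, since `__getitem__` returns `None` in that case):

```python
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from config import get_cfg_defaults
from datasets.datasets import get_datasets_LRS3

def skip_none_collate(batch):
    # Drop failed samples (the dataset returns None when landmark interpolation fails).
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch) if batch else None

cfg = get_cfg_defaults()
train_set, val_set, test_set = get_datasets_LRS3(cfg.dataset)

loader = DataLoader(train_set, batch_size=cfg.dataset.batch_size,
                    num_workers=cfg.dataset.num_workers, shuffle=True,
                    collate_fn=skip_none_collate)

for batch in loader:
    if batch is None:
        continue
    images = batch['image']        # (B, K, 3, 224, 224)
    landmarks = batch['landmark']  # (B, K, 68, 3), 68-point landmarks
    break
```
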
/external/spectre/datasets/extra_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .datasets import SpectreDataset
3 |
4 |
5 | def get_datasets_MEAD(config=None):
6 | import pandas as pd
7 | questionnaire_list = pd.read_csv("../utils/MEAD_test_set_final.csv")
8 | test_list = [(x[0],x[1]) for x in zip(questionnaire_list.name,questionnaire_list.text)]
9 | landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/MEAD_images_25fps"
10 |
11 | return None, None, SpectreDataset(test_list, landmarks_path, cfg=config, test=True)
12 |
13 |
14 | def get_datasets_TCDTIMIT(config=None):
15 | tcd_root = "/gpu-data3/filby/EAVTTS"
16 |
17 | landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/TCDTIMIT_images_25fps"
18 |
19 | root = f"{tcd_root}/TCDTIMIT_preprocessed/TCDSpkrIndepTrainSet.scp"
20 | files = open(root).readlines()
21 | train_list = []
22 | for file in files:
23 | f = file.strip().split("/")
24 | new_name = f"{f[0]}_{f[-1]}"
25 |
26 | ff = "/".join([f[0],f[1],f[2]])
27 |
28 | text = open(os.path.join(f"{tcd_root}/TCDTIMITprocessing/downloadTCDTIMIT/volunteers",ff,f[-1].upper().replace(".MP4",".txt"))).readlines()
29 |
30 | text = " ".join([x.split()[2].strip() for x in text])
31 |
32 | train_list.append((new_name.split(".")[0],text))
33 |
34 |
35 | root = f"{tcd_root}/TCDTIMIT_preprocessed/TCDSpkrIndepTestSet.scp"
36 | files = open(root).readlines()
37 | test_list = []
38 | for file in files:
39 | f = file.strip().split("/")
40 | new_name = f"{f[0]}_{f[-1]}"
41 |
42 | ff = "/".join([f[0],f[1],f[2]])
43 |
44 | text = open(os.path.join(f"{tcd_root}/TCDTIMITprocessing/downloadTCDTIMIT/volunteers",ff,f[-1].upper().replace(".MP4",".txt"))).readlines()
45 |
46 | text = " ".join([x.split()[2].strip() for x in text])
47 |
48 | test_list.append((new_name.split(".")[0],text.upper()))
49 |
50 |
51 | return SpectreDataset(train_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path, cfg=config, test=True)
--------------------------------------------------------------------------------
/external/spectre/demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os, sys
4 | import argparse
5 | # sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 | import torch
8 | import numpy as np
9 | import cv2
10 | from skimage.transform import estimate_transform, warp, resize, rescale
11 | import scipy.io
12 | import collections
13 | from tqdm import tqdm
14 | from datasets.data_utils import landmarks_interpolate
15 | from src.spectre import SPECTRE
16 | from config import cfg as spectre_cfg
17 | from src.utils.util import tensor2video
18 | import torchvision
19 |
20 | def extract_frames(video_path, detect_landmarks=True):
21 | videofolder = os.path.splitext(video_path)[0]
22 | os.makedirs(videofolder, exist_ok=True)
23 | vidcap = cv2.VideoCapture(video_path)
24 |
25 | if detect_landmarks is True:
26 | from external.Visual_Speech_Recognition_for_Multiple_Languages.tracker.face_tracker import FaceTracker
27 | from external.Visual_Speech_Recognition_for_Multiple_Languages.tracker.utils import get_landmarks
28 | face_tracker = FaceTracker()
29 |
30 | imagepath_list = []
31 | count = 0
32 |
33 | face_info = collections.defaultdict(list)
34 |
35 | fps = vidcap.get(cv2.CAP_PROP_FPS)
36 |
37 | with tqdm(total=int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
38 | while True:
39 | success, image = vidcap.read()
40 | if not success:
41 | break
42 |
43 | if detect_landmarks is True:
44 | detected_faces = face_tracker.face_detector(image, rgb=False)
45 | # -- face alignment
46 | landmarks, scores = face_tracker.landmark_detector(image, detected_faces, rgb=False)
47 | face_info['bbox'].append(detected_faces)
48 | face_info['landmarks'].append(landmarks)
49 | face_info['landmarks_scores'].append(scores)
50 |
51 | imagepath = os.path.join(videofolder, f'{count:06d}.jpg')
52 | cv2.imwrite(imagepath, image) # save frame as JPEG file
53 | count += 1
54 | imagepath_list.append(imagepath)
55 | pbar.update(1)
56 | pbar.set_description("Preprocessing frame %d" % count)
57 |
58 | landmarks = get_landmarks(face_info)
59 | print('video frames are stored in {}'.format(videofolder))
60 | return imagepath_list, landmarks, videofolder, fps
61 |
62 |
63 |
64 | def crop_face(frame, landmarks, scale=1.0):
65 | image_size = 224
66 | left = np.min(landmarks[:, 0])
67 | right = np.max(landmarks[:, 0])
68 | top = np.min(landmarks[:, 1])
69 | bottom = np.max(landmarks[:, 1])
70 |
71 | h, w, _ = frame.shape
72 | old_size = (right - left + bottom - top) / 2
73 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
74 |
75 | size = int(old_size * scale)
76 |
77 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2],
78 | [center[0] + size / 2, center[1] - size / 2]])
79 | DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
80 | tform = estimate_transform('similarity', src_pts, DST_PTS)
81 |
82 | return tform
83 |
84 |
85 |
86 | def main(args):
87 | args.crop_face = True
88 | spectre_cfg.pretrained_modelpath = "pretrained/spectre_model.tar"
89 | spectre_cfg.model.use_tex = False
90 |
91 | spectre = SPECTRE(spectre_cfg, args.device)
92 | spectre.eval()
93 |
94 | image_paths, landmarks, videofolder, fps = extract_frames(args.input, detect_landmarks=args.crop_face)
95 | if args.crop_face:
96 | landmarks = landmarks_interpolate(landmarks)
97 | if landmarks is None:
98 |             print('No faces detected in input {}'.format(args.input))
99 |             return
100 |
101 | original_video_length = len(image_paths)
102 | """ SPECTRE uses a temporal convolution of size 5.
103 |     Thus, in order to predict the parameters for a contiguous video we need to
104 | process the video in chunks of overlap 2, dropping values which were computed from the
105 | temporal kernel which uses pad 'same'. For the start and end of the video we
106 | pad using the first and last frame of the video.
107 | e.g., consider a video of size 48 frames and we want to predict it in chunks of 20 frames
108 | (due to memory limitations). We first pad the video two frames at the start and end using
109 |     the first and last frames respectively, making the video 52 frames long.
110 |
111 | Then we process independently the following chunks:
112 | [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
113 | [16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35]
114 | [32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51]]
115 |
116 | In the first chunk, after computing the 3DMM params we drop 0,1 and 18,19, since they were computed
117 | from the temporal kernel with padding (we followed the same procedure in training and computed loss
118 |     only from valid outputs of the temporal kernel). In the second chunk, we drop 16,17 and 34,35, and in
119 | the last chunk we drop 32,33 and 50,51. As a result we get:
120 | [2..17], [18..33], [34..49] (end included) which correspond to all frames of the original video
121 | (removing the initial padding).
122 | """
123 |
124 | # pad
125 | image_paths.insert(0,image_paths[0])
126 | image_paths.insert(0,image_paths[0])
127 | image_paths.append(image_paths[-1])
128 | image_paths.append(image_paths[-1])
129 |
130 | landmarks.insert(0,landmarks[0])
131 | landmarks.insert(0,landmarks[0])
132 | landmarks.append(landmarks[-1])
133 | landmarks.append(landmarks[-1])
134 |
135 | landmarks = np.array(landmarks)
136 |
137 | L = 50 # chunk size
138 |
139 | # create lists of overlapping indices
140 | indices = list(range(len(image_paths)))
141 | overlapping_indices = [indices[i: i + L] for i in range(0, len(indices), L-4)]
142 |
143 | if len(overlapping_indices[-1]) < 5:
144 |         # if the last chunk has fewer than 5 frames, merge it into the second-to-last chunk
145 | overlapping_indices[-2] = overlapping_indices[-2] + overlapping_indices[-1]
146 | overlapping_indices[-2] = np.unique(overlapping_indices[-2]).tolist()
147 | overlapping_indices = overlapping_indices[:-1]
148 |
149 | overlapping_indices = np.array(overlapping_indices)
150 |
151 | image_paths = np.array(image_paths) # do this to index with multiple indices
152 | all_shape_images = []
153 | all_images = []
154 |
155 | with torch.no_grad():
156 | for chunk_id in range(len(overlapping_indices)):
157 | print('Processing frames {} to {}'.format(overlapping_indices[chunk_id][0], overlapping_indices[chunk_id][-1]))
158 | image_paths_chunk = image_paths[overlapping_indices[chunk_id]]
159 |
160 | landmarks_chunk = landmarks[overlapping_indices[chunk_id]] if args.crop_face else None
161 |
162 | images_list = []
163 |
164 | """ load each image and crop it around the face if necessary """
165 | for j in range(len(image_paths_chunk)):
166 | frame = cv2.imread(image_paths_chunk[j])
167 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
168 | kpt = landmarks_chunk[j]
169 |
170 | tform = crop_face(frame,kpt,scale=1.6)
171 | cropped_image = warp(frame, tform.inverse, output_shape=(224, 224))
172 |
173 | images_list.append(cropped_image.transpose(2,0,1))
174 |
175 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32).to(args.device) #K,224,224,3
176 |
177 | codedict, initial_deca_exp, initial_deca_jaw = spectre.encode(images_array)
178 | codedict['exp'] = codedict['exp'] + initial_deca_exp
179 | codedict['pose'][..., 3:] = codedict['pose'][..., 3:] + initial_deca_jaw
180 |
181 | for key in codedict.keys():
182 | """ filter out invalid indices - see explanation at the top of the function """
183 |
184 | if chunk_id == 0 and chunk_id == len(overlapping_indices) - 1:
185 | pass
186 | elif chunk_id == 0:
187 | codedict[key] = codedict[key][:-2]
188 | elif chunk_id == len(overlapping_indices) - 1:
189 | codedict[key] = codedict[key][2:]
190 | else:
191 | codedict[key] = codedict[key][2:-2]
192 |
193 | opdict, visdict = spectre.decode(codedict, rendering=True, vis_lmk=False, return_vis=True)
194 | all_shape_images.append(visdict['shape_images'].detach().cpu())
195 | all_images.append(codedict['images'].detach().cpu())
196 |
197 | vid_shape = tensor2video(torch.cat(all_shape_images, dim=0))[2:-2] # remove padding
198 | vid_orig = tensor2video(torch.cat(all_images, dim=0))[2:-2] # remove padding
199 | grid_vid = np.concatenate((vid_shape, vid_orig), axis=2)
200 |
201 | assert original_video_length == len(vid_shape)
202 |
203 | if args.audio:
204 | import librosa
205 | wav, sr = librosa.load(args.input)
206 | wav = torch.FloatTensor(wav)
207 | if len(wav.shape) == 1:
208 | wav = wav.unsqueeze(0)
209 |
210 | torchvision.io.write_video(videofolder+"_shape.mp4", vid_shape, fps=fps, audio_codec='aac', audio_array=wav, audio_fps=sr)
211 | torchvision.io.write_video(videofolder+"_grid.mp4", grid_vid, fps=fps,
212 | audio_codec='aac', audio_array=wav, audio_fps=sr)
213 |
214 | else:
215 | torchvision.io.write_video(videofolder+"_shape.mp4", vid_shape, fps=fps)
216 | torchvision.io.write_video(videofolder+"_grid.mp4", grid_vid, fps=fps)
217 |
218 |
219 | if __name__ == '__main__':
220 |     parser = argparse.ArgumentParser(description='SPECTRE: visual speech-aware 3D facial expression reconstruction demo')
221 |
222 | parser.add_argument('-i', '--input', default='examples', type=str,
223 | help='path to the test data, can be image folder, image path, image list, video')
224 | # parser.add_argument('-o', '--outpath', default='examples/results', type=str,
225 | # help='path to the output directory, where results(obj, txt files) will be stored.')
226 | parser.add_argument('--device', default='cuda', type=str,
227 | help='set device, cpu for using cpu')
228 | parser.add_argument('--audio', action='store_true',
229 | help='extract audio from the original video and add it to the output video')
230 |
231 | main(parser.parse_args())
--------------------------------------------------------------------------------
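
The chunking scheme described in the docstring of `main` above can be hard to follow; this standalone sketch (illustrative numbers only: the docstring example uses chunks of 20 frames, while `demo.py` itself uses `L = 50`) reproduces the index bookkeeping and checks that every original frame is recovered:

```python
import numpy as np

L = 20                       # chunk size used in the docstring example
num_frames = 48              # original video length
padded = num_frames + 4      # two padding frames at the start and two at the end

indices = list(range(padded))
chunks = [indices[i:i + L] for i in range(0, padded, L - 4)]   # stride L-4 -> 4-frame overlap

# Corner case handled by demo.py: merge a too-short last chunk into the previous one.
if len(chunks[-1]) < 5:
    chunks[-2] = sorted(set(chunks[-2] + chunks[-1]))
    chunks = chunks[:-1]

kept = []
for chunk_id, chunk in enumerate(chunks):
    if chunk_id == 0:
        kept.append(chunk[:-2])      # first chunk: drop the last two outputs
    elif chunk_id == len(chunks) - 1:
        kept.append(chunk[2:])       # last chunk: drop the first two outputs
    else:
        kept.append(chunk[2:-2])     # middle chunks: drop two outputs on each side

flat = np.concatenate(kept)[2:-2]    # finally strip the artificial start/end padding
assert len(flat) == num_frames       # all 48 original frames are recovered
```
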
/external/spectre/get_training_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # file adapted from MICA https://raw.githubusercontent.com/Zielon/MICA
3 | #
4 | echo -e "\nDownloading deca_model..."
5 | #
6 | FILEID=1rp8kdyLPvErw2dTmqtjISRVvQLj6Yzje
7 | FILENAME=./data/deca_model.tar
8 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt
9 |
10 |
11 | echo "To download the Emotion Recognition model from EMOCA, which is used by SPECTRE for the expression loss, please register at:"
12 | echo -e '\e]8;;https://emoca.is.tue.mpg.de\ahttps://emoca.is.tue.mpg.de\e]8;;\a'
13 | while true; do
14 | read -p "I have registered and agreed to the license terms at https://emoca.is.tue.mpg.de? (y/n)" yn
15 | case $yn in
16 | [Yy]* ) break;;
17 | [Nn]* ) exit;;
18 | * ) echo "Please answer yes or no.";;
19 | esac
20 | done
21 |
22 | wget https://download.is.tue.mpg.de/emoca/assets/EmotionRecognition/image_based_networks/ResNet50.zip -O ResNet50.zip
23 | unzip ResNet50.zip -d data/
24 | rm ResNet50.zip
25 |
26 | echo -e "\nDownloading lipreading pretrained model..."
27 |
28 | FILEID=1yHd4QwC7K_9Ro2OM_hC7pKUT2URPvm_f
29 | FILENAME=LRS3_V_WER32.3.zip
30 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt
31 | unzip $FILENAME -d data/
32 | rm LRS3_V_WER32.3.zip
33 |
34 | echo -e "\nDownloading landmarks for LRS3 dataset ..."
35 |
36 | gdown --id 1QRdOgeHvmKK8t4hsceFVf_BSpidQfUyW
37 | unzip LRS3_landmarks.zip -d data/
38 | rm LRS3_landmarks.zip
39 |
40 |
41 |
42 | echo -e "\nInstallation has finished!"
43 |
--------------------------------------------------------------------------------
/external/spectre/main.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import yaml
4 | import torch.backends.cudnn as cudnn
5 | import torch
6 | import shutil
7 |
8 |
9 | def main(cfg):
10 |     # create folders
11 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.log_dir), exist_ok=True)
12 |
13 | if cfg.test_mode is False:
14 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.vis_dir), exist_ok=True)
15 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.val_vis_dir), exist_ok=True)
16 | with open(os.path.join(cfg.output_dir, 'full_config.yaml'), 'w') as f:
17 | yaml.dump(cfg, f, default_flow_style=False)
18 |
19 | # cudnn related setting
20 | cudnn.benchmark = True
21 | torch.backends.cudnn.deterministic = False
22 | torch.backends.cudnn.enabled = True
23 |
24 | # start training
25 | from src.trainer_spectre import Trainer
26 | from src.spectre import SPECTRE
27 | spectre = SPECTRE(cfg)
28 |
29 | trainer = Trainer(model=spectre, config=cfg)
30 |
31 | if cfg.test_mode:
32 | trainer.prepare_data()
33 | trainer.evaluate(trainer.test_datasets)
34 | else:
35 | trainer.fit()
36 |
37 | if __name__ == '__main__':
38 | from config import parse_args
39 | cfg = parse_args()
40 | cfg.exp_name = cfg.output_dir
41 |
42 | main(cfg)
43 |
--------------------------------------------------------------------------------
/external/spectre/quick_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # file adapted from MICA https://github.com/Zielon/MICA/
3 |
4 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
5 |
6 | # username and password input
7 | echo -e "\nIf you do not have an account you can register at https://flame.is.tue.mpg.de/ following the installation instruction."
8 | read -p "Username (FLAME):" username
9 | read -p "Password (FLAME):" password
10 | username=$(urle $username)
11 | password=$(urle $password)
12 |
13 | echo -e "\nDownloading FLAME..."
14 | mkdir -p data/FLAME2020/
15 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=flame&sfile=FLAME2020.zip&resume=1' -O './FLAME2020.zip' --no-check-certificate --continue
16 | unzip FLAME2020.zip -d data/FLAME2020/
17 | rm -rf FLAME2020.zip
18 |
19 | echo -e "\nDownloading the pretrained SPECTRE model..."
20 | gdown --id 1vmWX6QmXGPnXTXWFgj67oHzOoOmxBh6B
21 | mkdir -p pretrained/
22 | mv spectre_model.tar pretrained/
23 |
24 |
25 |
--------------------------------------------------------------------------------
/external/spectre/render.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import random
5 | import math
6 | import numpy as np
7 | import cv2
8 | import torch
9 | import torchvision
10 | from tqdm import tqdm
11 | from .config import cfg as spectre_cfg
12 | from .src.spectre import SPECTRE
13 |
14 |
15 | class Render():
16 | def __init__(self, device='cuda:0') -> None:
17 | # model
18 | self.device = device
19 | spectre_cfg.pretrained_modelpath = "external/spectre/pretrained/HDTF_pretrained/00032000.tar"
20 | spectre_cfg.model.use_tex = False
21 | self.spectre = SPECTRE(spectre_cfg, device=self.device)
22 | self.spectre.eval()
23 |
24 | def forward(self, exp_out, exp, data_info):
25 |         """Input: expression coefficients. Output: rendered mesh images."""
27 | n = self.cfg['trainer']['visual_images']
28 | if self.cfg['datasets']['dataset']=='mead':
29 | # template
30 | codedict = {}
31 | codedict['pose'] = torch.zeros((n*3, 6), dtype=torch.float).to(self.device)
32 | codedict['exp'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device)
33 | codedict['shape'] = torch.zeros((n*3, 100), dtype=torch.float).to(self.device)
34 | codedict['tex'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device)
35 | codedict['cam'] = torch.zeros((n*3, 3), dtype=torch.float).to(self.device)
36 | self.codedict = codedict
37 | # true coefficients
38 | coefficient_path = os.path.join(data_info, 'crop_head_info.npy')
39 | coefficient_info = np.load(coefficient_path, allow_pickle=True).item()['face3d_encode']
40 | coefficients = get_coefficients(coefficient_info)
41 | for key in coefficients:
42 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device)
43 |             start_vis = random.randint(0,exp.shape[1]-1-n) # starting frame
44 |
45 |             self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # generated parameters, rendered on the mean face
46 |             self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground-truth parameters, rendered on the mean face
47 |             self.codedict['exp'][2*n:3*n] = coefficients['exp'][start_vis: start_vis+n,:] # ground truth, rendered on the original face
48 |
49 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose
50 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:]
51 |
52 |             self.codedict['cam'][0:n] = coefficients['cam'][start_vis: start_vis+n, :] # take the cam of the n frames
53 | self.codedict['cam'][n:2*n] = coefficients['cam'][start_vis: start_vis+n, :]
54 | self.codedict['cam'][2*n:3*n] = coefficients['cam'][start_vis: start_vis+n, :]
55 |
56 |             self.codedict['pose'][2*n:3*n] = coefficients['pose'][start_vis: start_vis+n, :] # take the pose of the n frames
57 |             self.codedict['shape'][2*n:3*n] = coefficients['shape'][start_vis: start_vis+n, :] # take the shape of the n frames
58 |
59 | elif self.cfg['datasets']['dataset']=='mote':
60 | # template
61 | codedict = {}
62 | codedict['pose'] = torch.zeros((n*2, 6), dtype=torch.float).to(self.device)
63 | codedict['exp'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device)
64 | codedict['shape'] = torch.zeros((n*2, 100), dtype=torch.float).to(self.device)
65 | codedict['tex'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device)
66 | codedict['cam'] = torch.zeros((n*2, 3), dtype=torch.float).to(self.device)
67 | self.codedict = codedict
68 | # true coefficients
69 | # coefficient_path = os.path.join(self.cfg['datasets']['data_root'], data_info[0][0], data_info[1][0], 'train1_all.npz')
70 | # coefficient_info = np.load(coefficient_path, allow_pickle=True)['face'][-1*self.cfg['datasets']['eval_frames']:, :]
71 | # coefficients = get_coefficients(coefficient_info)
72 | # for key in coefficients:
73 | # coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device)
74 |             start_vis = random.randint(0,exp.shape[1]-1-n) # starting frame
75 |
76 |             self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # generated parameters, rendered on the mean face
77 |             self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground-truth parameters, rendered on the mean face
78 |
79 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose
80 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:]
81 |
82 | cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1) # cam
83 | self.codedict['cam'][0:n] = cam
84 | self.codedict['cam'][n:2*n] = cam
85 |
86 | opdict = self.spectre.decode(self.codedict, rendering=True, vis_lmk=False, return_vis=False)
87 | # rendered_images = torchvision.utils.make_grid(opdict['rendered_images'].detach().cpu(), nrow=n)
88 | return opdict['rendered_images']
89 |
90 | def infer(self, exp, exp_gt=None, render_batch=100):
91 |         """Input: expression coefficients. Output: rendered mesh images."""
93 | n = exp.shape[1]
94 | coefficients = {}
95 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(exp.device)
96 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(exp.device)
97 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(exp.device)
98 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(exp.device)
99 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(exp.device)
100 | coefficients_pred = copy.deepcopy(coefficients)
101 | # cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1).to(exp.device) # cam
102 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(exp.device) # cam
103 | for key in coefficients:
104 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(exp.device)
105 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(exp.device)
106 | if key == 'exp':
107 | if exp_gt is not None:
108 | coefficients[key] = exp_gt[0][:, :-3]
109 | coefficients_pred[key] = exp[0][:, :-3]
110 | elif key == 'pose':
111 | if exp_gt is not None:
112 | coefficients[key][:, -3:] = exp_gt[0][:, -3:]
113 | coefficients_pred[key][:, -3:] = exp[0][:, -3:]
114 | elif key == 'cam':
115 | coefficients[key] = cam[:, :]
116 | coefficients_pred[key] = cam[:, :]
117 | n_batch = int(math.ceil(n/render_batch))
118 | rendered_images, rendered_images_pred = [], []
119 | for i in range(n_batch):
120 | coefficients_render, coefficients_pred_render = {}, {}
121 | for k in coefficients:
122 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n)
123 | coefficients_render[k] = coefficients[k][start_f: end_f]
124 | coefficients_pred_render[k] = coefficients_pred[k][start_f: end_f]
125 |
126 | if exp_gt is not None:
127 | opdict = self.spectre.decode(coefficients_render, rendering=True, vis_lmk=False, return_vis=False)
128 | rendered_images.append(opdict['rendered_images'].detach().cpu())
129 | opdict_pred = self.spectre.decode(coefficients_pred_render, rendering=True, vis_lmk=False, return_vis=False)
130 | rendered_images_pred.append(opdict_pred['rendered_images'].detach().cpu())
131 | if exp_gt is not None:
132 | rendered_images_cat = torch.cat(rendered_images, dim=0)
133 | else:
134 | rendered_images_cat = None
135 | rendered_images_pred_cat = torch.cat(rendered_images_pred, dim=0)
136 |
137 | return rendered_images_cat, rendered_images_pred_cat
138 | # opdict = self.spectre.decode(coefficients, rendering=True, vis_lmk=False, return_vis=False)
139 | # opdict_pred = self.spectre.decode(coefficients_pred, rendering=True, vis_lmk=False, return_vis=False)
140 | # return opdict['rendered_images'], opdict_pred['rendered_images']
141 |
142 | def exp2mesh(self, coefficients_info, pose0=True, render_batch=100):
143 | n = coefficients_info.shape[0]
144 | if coefficients_info.shape[-1] == 53:
145 | coefficients = {}
146 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(coefficients_info.device)
147 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(coefficients_info.device)
148 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(coefficients_info.device)
149 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(coefficients_info.device)
150 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(coefficients_info.device)
151 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coefficients_info.device) # cam
152 | for key in coefficients:
153 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(coefficients_info.device)
154 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(coefficients_info.device)
155 | if key == 'exp':
156 | coefficients[key] = coefficients_info[:, 3:]
157 | elif key == 'pose':
158 | coefficients[key][:, -3:] = coefficients_info[:, :3]
159 | elif key == 'cam':
160 | coefficients[key] = cam[:, :]
161 | elif coefficients_info.shape[-1] == 209 or coefficients_info.shape[-1] == 213 or coefficients_info.shape[-1] == 236:
162 | coefficients = get_coefficients(coefficients_info)
163 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coefficients_info.device) # cam
164 | for key in coefficients:
165 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(coefficients_info.device)
166 | if pose0:
167 | if key == 'pose':
168 | coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3])
169 | elif key == 'shape' or key == 'tex':
170 | coefficients[key] = torch.zeros_like(coefficients[key])
171 | elif key == 'cam':
172 | coefficients[key] = cam
173 |
174 | n_batch = int(math.ceil(n/render_batch))
175 | rendered_images = []
176 | vertices = []
177 | for i in tqdm(range(n_batch)):
178 | coefficients_batch = {}
179 | for k in coefficients:
180 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n)
181 | coefficients_batch[k] = coefficients[k][start_f: end_f]
182 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False)
183 | rendered_images.append(opdict['rendered_images'].detach().cpu())
184 | vertices.append(opdict['verts'].detach().cpu())
185 | rendered_images_cat = torch.cat(rendered_images, dim=0)
186 | vertices_cat = torch.cat(vertices, dim=0)
187 | return rendered_images_cat, vertices_cat
188 |
189 | def coff2mesh(self, coeff, pose0=True, render_batch=100):
190 | n = coeff.shape[0]
191 |
192 | assert coeff.shape[-1] == 209 or coeff.shape[-1] == 213 or coeff.shape[-1] == 236
193 | coefficients = get_coefficients(coeff)
194 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coeff.device) # cam
195 | coefficients['cam'] = cam
196 | if pose0:
197 | coefficients['pose'][:, :3] = torch.zeros_like(coefficients['pose'][:, :3])
198 |
199 | n_batch = int(math.ceil(n/render_batch))
200 | rendered_images = []
201 | vertices = []
202 | for i in tqdm(range(n_batch)):
203 | coefficients_batch = {}
204 | for k in coefficients:
205 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n)
206 | coefficients_batch[k] = coefficients[k][start_f: end_f]
207 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False)
208 | rendered_images.append(opdict['rendered_images'].detach().cpu())
209 | vertices.append(opdict['verts'].detach().cpu())
210 | rendered_images_cat = torch.cat(rendered_images, dim=0)
211 | vertices_cat = torch.cat(vertices, dim=0)
212 | return rendered_images_cat, vertices_cat
213 |
214 | def coff2mesh_rawcam(self, coeff, pose0=True, render_batch=100):
215 | n = coeff.shape[0]
216 |
217 | if coeff.shape[-1] == 209 or coeff.shape[-1] == 213 or coeff.shape[-1] == 236:
218 | coefficients = get_coefficients(coeff)
219 | if pose0:
220 | for key in coefficients:
221 | if key == 'pose':
222 | coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3])
223 | # cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coeff.device) # cam
224 | # coefficients['cam'] = cam
225 | # for key in coefficients:
226 | # coefficients[key] = torch.FloatTensor(coefficients[key]).to(coefficients_info.device)
227 | # if pose0:
228 | # if key == 'pose':
229 | # coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3])
230 | # elif key == 'shape' or key == 'tex':
231 | # coefficients[key] = torch.zeros_like(coefficients[key])
232 | # elif key == 'cam':
233 | # coefficients[key] = cam
234 |
235 | n_batch = int(math.ceil(n/render_batch))
236 | rendered_images = []
237 | vertices = []
238 | for i in tqdm(range(n_batch)):
239 | coefficients_batch = {}
240 | for k in coefficients:
241 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n)
242 | coefficients_batch[k] = coefficients[k][start_f: end_f]
243 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False)
244 | rendered_images.append(opdict['rendered_images'].detach().cpu())
245 | vertices.append(opdict['verts'].detach().cpu())
246 | rendered_images_cat = torch.cat(rendered_images, dim=0)
247 | vertices_cat = torch.cat(vertices, dim=0)
248 | return rendered_images_cat, vertices_cat
249 |
250 |
251 | def get_coefficients(coefficient_info):
252 | coefficient_dict = {}
253 | coefficient_dict['pose'] = coefficient_info[:, :6]
254 | coefficient_dict['exp'] = coefficient_info[:, 6:56]
255 | coefficient_dict['shape'] = coefficient_info[:, 56:156]
256 | coefficient_dict['tex'] = coefficient_info[:, 156:206]
257 | coefficient_dict['cam'] = coefficient_info[:, 206:209]
258 | # coefficient_dict['light'] = coefficient_info[:, 209:236]
259 | return coefficient_dict
260 |
--------------------------------------------------------------------------------
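A minimal, self-contained sketch (not part of the repository) of the coefficient layout that `get_coefficients` above assumes: pose (6) + expression (50) + shape (100) + texture (50) + camera (3) = 209 dimensions, with the 236-dim variant appending 27 lighting coefficients that are currently commented out. The semantic labels follow the FLAME/DECA convention used elsewhere in the repo; the frame count and random values are illustrative only.

import torch

T = 8                          # illustrative number of frames
coeff = torch.randn(T, 209)    # flattened per-frame coefficients

coefficients = {
    'pose':  coeff[:, 0:6],      # global rotation (3) + jaw rotation (3)
    'exp':   coeff[:, 6:56],     # 50 expression parameters
    'shape': coeff[:, 56:156],   # 100 identity parameters
    'tex':   coeff[:, 156:206],  # 50 texture parameters
    'cam':   coeff[:, 206:209],  # weak-perspective camera [scale, tx, ty]
}
print({k: tuple(v.shape) for k, v in coefficients.items()})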
/external/spectre/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit_image==0.19.3
2 | kornia==0.6.6
3 | chumpy==0.70
4 | librosa==0.9.2
5 | av==9.2.0
6 | loguru==0.6.0
7 | tensorboard==2.9.1
8 | pytorch_lightning==1.5
9 | opencv-python==4.6.0.66
10 | phonemizer==3.2.1
11 | jiwer==2.3.0
--------------------------------------------------------------------------------
/external/spectre/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/src/__init__.py
--------------------------------------------------------------------------------
/external/spectre/src/models/FLAME.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | import pickle
20 | import torch.nn.functional as F
21 |
22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler
23 |
24 | def to_tensor(array, dtype=torch.float32):
25 | if 'torch.tensor' not in str(type(array)):
26 | return torch.tensor(array, dtype=dtype)
27 | def to_np(array, dtype=np.float32):
28 | if 'scipy.sparse' in str(type(array)):
29 | array = array.todense()
30 | return np.array(array, dtype=dtype)
31 |
32 | class Struct(object):
33 | def __init__(self, **kwargs):
34 | for key, val in kwargs.items():
35 | setattr(self, key, val)
36 |
37 | class FLAME(nn.Module):
38 | """
39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py
40 | Given flame parameters this class generates a differentiable FLAME function
41 |     which outputs a mesh and 2D/3D facial landmarks
42 | """
43 | def __init__(self, config):
44 | super(FLAME, self).__init__()
45 | # print("creating the FLAME Decoder")
46 | with open(config.flame_model_path, 'rb') as f:
47 | ss = pickle.load(f, encoding='latin1')
48 | flame_model = Struct(**ss)
49 |
50 | self.dtype = torch.float32
51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long))
52 | # The vertices of the template model
53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype))
54 | # The shape components and expression
55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype)
56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2)
57 | self.register_buffer('shapedirs', shapedirs)
58 | # The pose components
59 | num_pose_basis = flame_model.posedirs.shape[-1]
60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T
61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
62 | #
63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype))
64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1
65 | self.register_buffer('parents', parents)
66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), dtype=self.dtype))
67 |
68 | # Fixing Eyeball and neck rotation
69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False)
70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose,
71 | requires_grad=False))
72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False)
73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose,
74 | requires_grad=False))
75 |
76 | # Static and Dynamic Landmark embeddings for FLAME
77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1')
78 | lmk_embeddings = lmk_embeddings[()]
79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long())
80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype))
81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long())
82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype))
83 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long())
84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype))
85 |
86 | neck_kin_chain = []; NECK_IDX=1
87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long)
88 | while curr_idx != -1:
89 | neck_kin_chain.append(curr_idx)
90 | curr_idx = self.parents[curr_idx]
91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain))
92 |
93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx,
94 | dynamic_lmk_b_coords,
95 | neck_kin_chain, dtype=torch.float32):
96 | """
97 |             Selects the face contour depending on the relative position of the head
98 | Input:
99 | vertices: N X num_of_vertices X 3
100 | pose: N X full pose
101 | dynamic_lmk_faces_idx: The list of contour face indexes
102 | dynamic_lmk_b_coords: The list of contour barycentric weights
103 | neck_kin_chain: The tree to consider for the relative rotation
104 | dtype: Data type
105 | return:
106 | The contour face indexes and the corresponding barycentric weights
107 | """
108 |
109 | batch_size = pose.shape[0]
110 |
111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
112 | neck_kin_chain)
113 | rot_mats = batch_rodrigues(
114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
115 |
116 | rel_rot_mat = torch.eye(3, device=pose.device,
117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
118 | for idx in range(len(neck_kin_chain)):
119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
120 |
121 | y_rot_angle = torch.round(
122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
123 | max=39)).to(dtype=torch.long)
124 |
125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
128 | y_rot_angle = (neg_mask * neg_vals +
129 | (1 - neg_mask) * y_rot_angle)
130 |
131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
132 | 0, y_rot_angle)
133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
134 | 0, y_rot_angle)
135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
136 |
137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords):
138 | """
139 | Calculates landmarks by barycentric interpolation
140 | Input:
141 | vertices: torch.tensor NxVx3, dtype = torch.float32
142 | The tensor of input vertices
143 | faces: torch.tensor (N*F)x3, dtype = torch.long
144 | The faces of the mesh
145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long
146 | The tensor with the indices of the faces used to calculate the
147 | landmarks.
148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32
149 | The tensor of barycentric coordinates that are used to interpolate
150 | the landmarks
151 |
152 | Returns:
153 | landmarks: torch.tensor NxLx3, dtype = torch.float32
154 | The coordinates of the landmarks for each mesh in the batch
155 | """
156 | # Extract the indices of the vertices for each face
157 | # NxLx3
158 |         batch_size, num_verts = vertices.shape[:2]
159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1)
161 |
162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to(
163 | device=vertices.device) * num_verts
164 |
165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces]
166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
167 | return landmarks
168 |
169 | def seletec_3d68(self, vertices):
170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1),
172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1))
173 | return landmarks3d
174 |
175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None):
176 | """
177 | Input:
178 | shape_params: N X number of shape parameters
179 | expression_params: N X number of expression parameters
180 | pose_params: N X number of pose parameters (6)
181 |         return:
182 | vertices: N X V X 3
183 | landmarks: N X number of landmarks X 3
184 | """
185 | batch_size = shape_params.shape[0]
186 | if pose_params is None:
187 | pose_params = self.eye_pose.expand(batch_size, -1)
188 | if eye_pose_params is None:
189 | eye_pose_params = self.eye_pose.expand(batch_size, -1)
190 | betas = torch.cat([shape_params, expression_params], dim=1)
191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1)
192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
193 |
194 | vertices, _ = lbs(betas, full_pose, template_vertices,
195 | self.shapedirs, self.posedirs,
196 | self.J_regressor, self.parents,
197 | self.lbs_weights, dtype=self.dtype)
198 |
199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
201 |
202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
203 | full_pose, self.dynamic_lmk_faces_idx,
204 | self.dynamic_lmk_bary_coords,
205 | self.neck_kin_chain, dtype=self.dtype)
206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1)
208 |
209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor,
210 | lmk_faces_idx,
211 | lmk_bary_coords)
212 | bz = vertices.shape[0]
213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor,
214 | self.full_lmk_faces_idx.repeat(bz, 1),
215 | self.full_lmk_bary_coords.repeat(bz, 1, 1))
216 | return vertices, landmarks2d, landmarks3d
217 |
218 | class FLAMETex(nn.Module):
219 | """
220 | FLAME texture:
221 | https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64
222 | FLAME texture converted from BFM:
223 | https://github.com/TimoBolkart/BFM_to_FLAME
224 | """
225 | def __init__(self, config):
226 | super(FLAMETex, self).__init__()
227 | if config.tex_type == 'BFM':
228 | mu_key = 'MU'
229 | pc_key = 'PC'
230 | n_pc = 199
231 | tex_path = config.tex_path
232 | tex_space = np.load(tex_path)
233 | texture_mean = tex_space[mu_key].reshape(1, -1)
234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)
235 |
236 | elif config.tex_type == 'FLAME':
237 | mu_key = 'mean'
238 | pc_key = 'tex_dir'
239 | n_pc = 200
240 | tex_path = config.flame_tex_path
241 | tex_space = np.load(tex_path)
242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255.
243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255.
244 | else:
245 |             print('texture type ', config.tex_type, 'does not exist!')
246 | raise NotImplementedError
247 |
248 | n_tex = config.n_tex
249 | num_components = texture_basis.shape[1]
250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...]
251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...]
252 | self.register_buffer('texture_mean', texture_mean)
253 | self.register_buffer('texture_basis', texture_basis)
254 |
255 |
256 | def forward(self, texcode):
257 | '''
258 | texcode: [batchsize, n_tex]
259 | texture: [bz, 3, 256, 256], range: 0-1
260 | '''
261 |
262 | bs = texcode.shape[0]
263 | texcode = texcode[:1]
264 |
265 | # we use the same (first frame) texture for all frames
266 |
267 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1)
268 |
269 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2)
270 | texture = F.interpolate(texture, [256, 256])
271 | texture = texture[:,[2,1,0], :,:].repeat(bs,1,1,1)
272 | return texture
273 |
--------------------------------------------------------------------------------
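A short, runnable sketch of how `FLAME.forward` above packs its inputs before linear blend skinning: `betas` concatenates the shape and expression codes, and the 15-dim `full_pose` interleaves the 6-dim input pose with the fixed (zero) neck and eye poses. The 100/50 sizes are the usual `n_shape`/`n_exp` config values and are an assumption here.

import torch

B = 2
shape_params = torch.zeros(B, 100)       # identity coefficients (assumed n_shape=100)
expression_params = torch.zeros(B, 50)   # expression coefficients (assumed n_exp=50)
pose_params = torch.zeros(B, 6)          # global rotation (3) + jaw rotation (3)
neck_pose = torch.zeros(B, 3)            # frozen to zero in the module above
eye_pose = torch.zeros(B, 6)             # frozen to zero in the module above

betas = torch.cat([shape_params, expression_params], dim=1)        # [B, 150]
full_pose = torch.cat([pose_params[:, :3], neck_pose,
                       pose_params[:, 3:], eye_pose], dim=1)       # [B, 15]
print(betas.shape, full_pose.shape)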
/external/spectre/src/models/encoders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 | from . import resnet
6 |
7 |
8 | class PerceptualEncoder(nn.Module):
9 | def __init__(self, outsize, cfg):
10 | super(PerceptualEncoder, self).__init__()
11 | if cfg.backbone == "mobilenetv2":
12 | self.encoder = torch.hub.load('pytorch/vision:v0.8.1', 'mobilenet_v2', pretrained=True)
13 | feature_size = 1280
14 | elif cfg.backbone == "resnet50":
15 | self.encoder = resnet.load_ResNet50Model() #out: 2048
16 | feature_size = 2048
17 |
18 | ### regressor
19 | self.temporal = nn.Sequential(
20 | nn.Conv1d(in_channels=feature_size, out_channels=256, kernel_size=5, stride=1, padding=2),
21 | nn.BatchNorm1d(256),
22 | nn.ReLU()
23 | )
24 |
25 | self.layers = nn.Sequential(
26 | nn.Linear(256, 53),
27 | )
28 |
29 | self.backbone = cfg.backbone
30 |
31 | def forward(self, inputs):
32 | if self.backbone == 'resnet50':
33 | features = self.encoder(inputs).squeeze(-1).squeeze(-1)
34 | else:
35 | features = self.encoder.features(inputs)
36 | features = nn.functional.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
37 |
38 |         # features: [T, C]; reshape to [1, C, T] for the temporal convolution below
39 | features = features.permute(1,0).unsqueeze(0)
40 |
41 | features = self.temporal(features)
42 |
43 | features = features.squeeze(0).permute(1,0)
44 |
45 | parameters = self.layers(features)
46 |
47 |         parameters[...,50] = F.relu(parameters[...,50]) # a negative jaw x is highly improbable and can introduce artifacts
48 |
49 | return parameters[...,:50], parameters[...,50:]
50 |
51 |
52 | class ResnetEncoder(nn.Module):
53 | def __init__(self, outsize):
54 | super(ResnetEncoder, self).__init__()
55 |
56 | feature_size = 2048
57 |
58 | self.encoder = resnet.load_ResNet50Model() #out: 2048
59 | ### regressor
60 | self.layers = nn.Sequential(
61 | nn.Linear(feature_size, 1024),
62 | nn.ReLU(),
63 | nn.Linear(1024, outsize)
64 | )
65 |
66 | def forward(self, inputs):
67 | features = self.encoder(inputs)
68 | parameters = self.layers(features)
69 |
70 | return parameters
71 |
72 |
--------------------------------------------------------------------------------
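A self-contained sketch of the tensor handling in `PerceptualEncoder.forward` above: per-frame backbone features are treated as one sequence, passed through a temporal Conv1d, and regressed to 50 expression plus 3 jaw-pose parameters. The frame count and the freshly initialized layers below are stand-ins, not the trained weights.

import torch
import torch.nn as nn

T, C = 16, 1280                           # frames in the clip, mobilenetv2 feature size
features = torch.randn(T, C)              # pooled per-frame features

temporal = nn.Sequential(
    nn.Conv1d(C, 256, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm1d(256),
    nn.ReLU(),
)
head = nn.Linear(256, 53)                 # 50 expression + 3 jaw-pose parameters

x = features.permute(1, 0).unsqueeze(0)   # [T, C] -> [1, C, T]: the whole clip is one "batch"
x = temporal(x).squeeze(0).permute(1, 0)  # back to [T, 256]
params = head(x)
exp, jaw = params[..., :50], params[..., 50:]
print(exp.shape, jaw.shape)               # [16, 50] and [16, 3]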
/external/spectre/src/models/expression_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import torch.nn as nn
17 | from torchvision import models
18 | from . import resnet
19 |
20 |
21 |
22 | class ExpressionLossNet(nn.Module):
23 | """ Code borrowed from EMOCA https://github.com/radekd91/emoca """
24 | def __init__(self):
25 | super(ExpressionLossNet, self).__init__()
26 |
27 | self.backbone = resnet.load_ResNet50Model() #out: 2048
28 |
29 | self.linear = nn.Sequential(
30 | nn.Linear(2048, 10))
31 |
32 | def forward2(self, inputs):
33 | features = self.backbone(inputs)
34 | out = self.linear(features)
35 | return features, out
36 |
37 | def forward(self, inputs):
38 | features = self.backbone(inputs)
39 | return features
40 |
--------------------------------------------------------------------------------
/external/spectre/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Soubhik Sanyal
3 | Copyright (c) 2019, Soubhik Sanyal
4 | All rights reserved.
5 | Loads different resnet models
6 | """
7 | '''
8 | file: Resnet.py
9 | date: 2018_05_02
10 | author: zhangxiong(1025679612@qq.com)
11 | mark: copied from pytorch source code
12 | '''
13 |
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch
17 | from torch.nn.parameter import Parameter
18 | import torch.optim as optim
19 | import numpy as np
20 | import math
21 | import torchvision
22 |
23 | class ResNet(nn.Module):
24 | def __init__(self, block, layers, num_classes=1000):
25 | self.inplanes = 64
26 | super(ResNet, self).__init__()
27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
28 | bias=False)
29 | self.bn1 = nn.BatchNorm2d(64)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
32 | self.layer1 = self._make_layer(block, 64, layers[0])
33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
36 | self.avgpool = nn.AvgPool2d(7, stride=1)
37 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
38 |
39 | for m in self.modules():
40 | if isinstance(m, nn.Conv2d):
41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
42 | m.weight.data.normal_(0, math.sqrt(2. / n))
43 | elif isinstance(m, nn.BatchNorm2d):
44 | m.weight.data.fill_(1)
45 | m.bias.data.zero_()
46 |
47 | def _make_layer(self, block, planes, blocks, stride=1):
48 | downsample = None
49 | if stride != 1 or self.inplanes != planes * block.expansion:
50 | downsample = nn.Sequential(
51 | nn.Conv2d(self.inplanes, planes * block.expansion,
52 | kernel_size=1, stride=stride, bias=False),
53 | nn.BatchNorm2d(planes * block.expansion),
54 | )
55 |
56 | layers = []
57 | layers.append(block(self.inplanes, planes, stride, downsample))
58 | self.inplanes = planes * block.expansion
59 | for i in range(1, blocks):
60 | layers.append(block(self.inplanes, planes))
61 |
62 | return nn.Sequential(*layers)
63 |
64 | def forward(self, x):
65 | x = self.conv1(x)
66 | x = self.bn1(x)
67 | x = self.relu(x)
68 | x = self.maxpool(x)
69 |
70 | x = self.layer1(x)
71 | x = self.layer2(x)
72 | x = self.layer3(x)
73 | x1 = self.layer4(x)
74 |
75 | x2 = self.avgpool(x1)
76 | x2 = x2.view(x2.size(0), -1)
77 | # x = self.fc(x)
78 | ## x2: [bz, 2048] for shape
79 | ## x1: [bz, 2048, 7, 7] for texture
80 | return x2
81 |
82 | class Bottleneck(nn.Module):
83 | expansion = 4
84 |
85 | def __init__(self, inplanes, planes, stride=1, downsample=None):
86 | super(Bottleneck, self).__init__()
87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
88 | self.bn1 = nn.BatchNorm2d(planes)
89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
90 | padding=1, bias=False)
91 | self.bn2 = nn.BatchNorm2d(planes)
92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
93 | self.bn3 = nn.BatchNorm2d(planes * 4)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.downsample = downsample
96 | self.stride = stride
97 |
98 | def forward(self, x):
99 | residual = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | out = self.bn3(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out += residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 | def conv3x3(in_planes, out_planes, stride=1):
121 | """3x3 convolution with padding"""
122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
123 | padding=1, bias=False)
124 |
125 | class BasicBlock(nn.Module):
126 | expansion = 1
127 |
128 | def __init__(self, inplanes, planes, stride=1, downsample=None):
129 | super(BasicBlock, self).__init__()
130 | self.conv1 = conv3x3(inplanes, planes, stride)
131 | self.bn1 = nn.BatchNorm2d(planes)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.conv2 = conv3x3(planes, planes)
134 | self.bn2 = nn.BatchNorm2d(planes)
135 | self.downsample = downsample
136 | self.stride = stride
137 |
138 | def forward(self, x):
139 | residual = x
140 |
141 | out = self.conv1(x)
142 | out = self.bn1(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv2(out)
146 | out = self.bn2(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 | def copy_parameter_from_resnet(model, resnet_dict):
157 | cur_state_dict = model.state_dict()
158 | # import ipdb; ipdb.set_trace()
159 | for name, param in list(resnet_dict.items())[0:None]:
160 | if name not in cur_state_dict:
161 | # print(name, ' not available in reconstructed resnet')
162 | continue
163 | if isinstance(param, Parameter):
164 | param = param.data
165 | try:
166 | cur_state_dict[name].copy_(param)
167 | except:
168 | # print(name, ' is inconsistent!')
169 | continue
170 | # print('copy resnet state dict finished!')
171 | # import ipdb; ipdb.set_trace()
172 |
173 |
174 | def load_ResNet50Model():
175 | model = ResNet(Bottleneck, [3, 4, 6, 3])
176 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = True).state_dict())
177 | return model
178 |
179 | def load_ResNet101Model():
180 | model = ResNet(Bottleneck, [3, 4, 23, 3])
181 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict())
182 | return model
183 |
184 | def load_ResNet152Model():
185 | model = ResNet(Bottleneck, [3, 8, 36, 3])
186 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict())
187 | return model
188 |
189 | # model.load_state_dict(checkpoint['model_state_dict'])
190 |
191 |
192 | ######## Unet
193 |
194 | class DoubleConv(nn.Module):
195 | """(convolution => [BN] => ReLU) * 2"""
196 |
197 | def __init__(self, in_channels, out_channels):
198 | super().__init__()
199 | self.double_conv = nn.Sequential(
200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
201 | nn.BatchNorm2d(out_channels),
202 | nn.ReLU(inplace=True),
203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
204 | nn.BatchNorm2d(out_channels),
205 | nn.ReLU(inplace=True)
206 | )
207 |
208 | def forward(self, x):
209 | return self.double_conv(x)
210 |
211 |
212 | class Down(nn.Module):
213 | """Downscaling with maxpool then double conv"""
214 |
215 | def __init__(self, in_channels, out_channels):
216 | super().__init__()
217 | self.maxpool_conv = nn.Sequential(
218 | nn.MaxPool2d(2),
219 | DoubleConv(in_channels, out_channels)
220 | )
221 |
222 | def forward(self, x):
223 | return self.maxpool_conv(x)
224 |
225 |
226 | class Up(nn.Module):
227 | """Upscaling then double conv"""
228 |
229 | def __init__(self, in_channels, out_channels, bilinear=True):
230 | super().__init__()
231 |
232 | # if bilinear, use the normal convolutions to reduce the number of channels
233 | if bilinear:
234 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
235 | else:
236 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
237 |
238 | self.conv = DoubleConv(in_channels, out_channels)
239 |
240 | def forward(self, x1, x2):
241 | x1 = self.up(x1)
242 | # input is CHW
243 | diffY = x2.size()[2] - x1.size()[2]
244 | diffX = x2.size()[3] - x1.size()[3]
245 |
246 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
247 | diffY // 2, diffY - diffY // 2])
248 | # if you have padding issues, see
249 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
250 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
251 | x = torch.cat([x2, x1], dim=1)
252 | return self.conv(x)
253 |
254 |
255 | class OutConv(nn.Module):
256 | def __init__(self, in_channels, out_channels):
257 | super(OutConv, self).__init__()
258 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
259 |
260 | def forward(self, x):
261 | return self.conv(x)
--------------------------------------------------------------------------------
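For reference, `load_ResNet50Model` above yields a backbone that maps images to pooled 2048-dim features (the classifier head is dropped). A rough equivalent using torchvision directly, with random weights since this is only a shape check:

import torch
import torchvision

backbone = torchvision.models.resnet50()   # random weights; the module above copies pretrained ones
backbone.fc = torch.nn.Identity()          # keep the pooled 2048-dim features, drop the classifier
feats = backbone(torch.randn(2, 3, 224, 224))
print(feats.shape)                         # torch.Size([2, 2048])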
/external/spectre/src/spectre.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # Using this computer program means that you agree to the terms
6 | # in the LICENSE file included with this software distribution.
7 | # Any use not explicitly granted by the LICENSE is prohibited.
8 | #
9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
11 | # for Intelligent Systems. All rights reserved.
12 | #
13 | # For comments or questions, please email us at deca@tue.mpg.de
14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
15 |
16 | import os
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from .models.encoders import PerceptualEncoder
22 | from .utils.renderer import SRenderY, set_rasterizer
23 | from .models.encoders import ResnetEncoder
24 | from .models.FLAME import FLAME, FLAMETex
25 | from .utils import util
26 | from .utils.tensor_cropper import transform_points
27 | from skimage.io import imread
28 | torch.backends.cudnn.benchmark = True
29 | import numpy as np
30 |
31 | class SPECTRE(nn.Module):
32 | def __init__(self, config=None, device='cuda'):
33 | super(SPECTRE, self).__init__()
34 | self.cfg = config
35 | self.device = device
36 | self.image_size = self.cfg.dataset.image_size
37 | self.uv_size = self.cfg.model.uv_size
38 | self._create_model(self.cfg.model)
39 | self._setup_renderer(self.cfg.model)
40 |
41 |
42 | def _setup_renderer(self, model_cfg):
43 | set_rasterizer(self.cfg.rasterizer_type)
44 | self.render = SRenderY(self.image_size, obj_filename=model_cfg.topology_path, uv_size=model_cfg.uv_size, rasterizer_type=self.cfg.rasterizer_type).to(self.device)
45 | # face mask for rendering details
46 | mask = imread(model_cfg.face_eye_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous()
47 | self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
48 | mask = imread(model_cfg.face_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous()
49 | self.uv_face_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
50 | # displacement correction
51 | fixed_dis = np.load(model_cfg.fixed_displacement_path)
52 | self.fixed_uv_dis = torch.tensor(fixed_dis).float().to(self.device)
53 | # mean texture
54 | mean_texture = imread(model_cfg.mean_tex_path).astype(np.float32)/255.; mean_texture = torch.from_numpy(mean_texture.transpose(2,0,1))[None,:,:,:].contiguous()
55 | self.mean_texture = F.interpolate(mean_texture, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device)
56 | # dense mesh template, for save detail mesh
57 | self.dense_template = np.load(model_cfg.dense_template_path, allow_pickle=True, encoding='latin1').item()
58 |
59 | def _create_model(self, model_cfg):
60 | # set up parameters
61 | self.n_param = model_cfg.n_shape + model_cfg.n_tex + model_cfg.n_exp + model_cfg.n_pose + model_cfg.n_cam + model_cfg.n_light
62 | self.n_cond = model_cfg.n_exp + 3 # exp + jaw pose
63 | self.num_list = [model_cfg.n_shape, model_cfg.n_tex, model_cfg.n_exp, model_cfg.n_pose, model_cfg.n_cam,
64 | model_cfg.n_light]
65 | self.param_dict = {i: model_cfg.get('n_' + i) for i in model_cfg.param_list}
66 |
67 | # encoders
68 | self.E_flame = ResnetEncoder(outsize=self.n_param).to(self.device)
69 |
70 | self.E_expression = PerceptualEncoder(model_cfg.n_exp, model_cfg).to(self.device)
71 |
72 | # decoders
73 | self.flame = FLAME(model_cfg).to(self.device)
74 | if model_cfg.use_tex:
75 | self.flametex = FLAMETex(model_cfg).to(self.device)
76 |
77 | # resume model
78 | model_path = self.cfg.pretrained_modelpath
79 | if os.path.exists(model_path):
80 | # print(f'trained model found. load {model_path}')
81 | checkpoint = torch.load(model_path)
82 |
83 | if 'state_dict' in checkpoint.keys():
84 | self.checkpoint = checkpoint['state_dict']
85 | else:
86 | self.checkpoint = checkpoint
87 |
88 | processed_checkpoint = {}
89 | processed_checkpoint["E_flame"] = {}
90 | processed_checkpoint["E_expression"] = {}
91 | if 'deca' in list(self.checkpoint.keys())[0]:
92 | for key in self.checkpoint.keys():
93 | # print(key)
94 | k = key.replace("deca.","")
95 | if "E_flame" in key:
96 | processed_checkpoint["E_flame"][k.replace("E_flame.","")] = self.checkpoint[key]#.replace("E_flame","")
97 | elif "E_expression" in key:
98 | processed_checkpoint["E_expression"][k.replace("E_expression.","")] = self.checkpoint[key]#.replace("E_flame","")
99 | else:
100 | pass
101 |
102 | else:
103 | processed_checkpoint = self.checkpoint
104 |
105 |
106 | self.E_flame.load_state_dict(processed_checkpoint['E_flame'], strict=True)
107 | try:
108 | m,u = self.E_expression.load_state_dict(processed_checkpoint['E_expression'], strict=True)
109 | # print('Missing keys', m)
110 | # print('Unexpected keys', u)
111 | # pass
112 | except Exception as e:
113 | print(f'Missing keys {e} in expression encoder weights. If starting training from scratch this is normal.')
114 | else:
115 |             raise FileNotFoundError(f'please check model path: {model_path}')
116 |
117 | # eval mode
118 | self.E_flame.eval()
119 |
120 | self.E_expression.eval()
121 |
122 | self.E_flame.requires_grad_(False)
123 |
124 |
125 | def decompose_code(self, code, num_dict):
126 | ''' Convert a flattened parameter vector to a dictionary of parameters
127 | code_dict.keys() = ['shape', 'tex', 'exp', 'pose', 'cam', 'light']
128 | '''
129 | code_dict = {}
130 | start = 0
131 | for key in num_dict:
132 | end = start + int(num_dict[key])
133 | code_dict[key] = code[:, start:end]
134 | start = end
135 | if key == 'light':
136 | code_dict[key] = code_dict[key].reshape(code_dict[key].shape[0], 9, 3)
137 | return code_dict
138 |
139 | def encode(self, images):
140 | with torch.no_grad():
141 | parameters = self.E_flame(images)
142 |
143 | codedict = self.decompose_code(parameters, self.param_dict)
144 | deca_exp = codedict['exp'].clone()
145 | deca_jaw = codedict['pose'][:,3:].clone()
146 |
147 | codedict['images'] = images
148 |
149 | codedict['exp'], jaw = self.E_expression(images)
150 | codedict['pose'][:, 3:] = jaw
151 |
152 | return codedict, deca_exp, deca_jaw
153 |
154 |
155 | def decode(self, codedict, rendering=True, vis_lmk=False, return_vis=False,
156 | render_orig=False, original_image=None, tform=None):
157 |         images = codedict.get('images', None)  # may be absent when decoding raw coefficient dicts
158 | # batch_size = images.shape[0]
159 | batch_size = 1
160 |
161 | ## decode
162 | verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape'], expression_params=codedict['exp'],
163 | pose_params=codedict['pose'])
164 | if self.cfg.model.use_tex:
165 | albedo = self.flametex(codedict['tex']).detach()
166 | else:
167 | albedo = torch.zeros([batch_size, 3, self.uv_size, self.uv_size], device=self.device)
168 | landmarks3d_world = landmarks3d.clone()
169 |
170 | ## projection
171 | landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:, :, :2];
172 | landmarks2d[:, :, 1:] = -landmarks2d[:, :,
173 | 1:]
174 | landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam']);
175 | landmarks3d[:, :, 1:] = -landmarks3d[:, :,
176 | 1:]
177 | trans_verts = util.batch_orth_proj(verts, codedict['cam']);
178 | trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]
179 | opdict = {
180 | 'verts': verts,
181 | 'trans_verts': trans_verts,
182 | 'landmarks2d': landmarks2d,
183 | 'landmarks3d': landmarks3d,
184 | 'landmarks3d_world': landmarks3d_world,
185 | }
186 |
187 | if rendering and render_orig and original_image is not None and tform is not None:
188 | points_scale = [self.image_size, self.image_size]
189 | _, _, h, w = original_image.shape
190 | trans_verts = transform_points(trans_verts, tform, points_scale, [h, w])
191 | landmarks2d = transform_points(landmarks2d, tform, points_scale, [h, w])
192 | landmarks3d = transform_points(landmarks3d, tform, points_scale, [h, w])
193 | background = images
194 | else:
195 | h, w = self.image_size, self.image_size
196 | background = None
197 |
198 |
199 | if rendering:
200 | if self.cfg.model.use_tex:
201 | ops = self.render(verts, trans_verts, albedo, codedict['light'])
202 | ## output
203 | opdict['predicted_inner_mouth'] = ops['predicted_inner_mouth']
204 | opdict['grid'] = ops['grid']
205 | opdict['rendered_images'] = ops['images']
206 | opdict['alpha_images'] = ops['alpha_images']
207 | opdict['normal_images'] = ops['normal_images']
208 | opdict['images'] = images
209 |
210 | else:
211 | shape_images, _, grid, alpha_images, pos_mask = self.render.render_shape(verts, trans_verts, h=h, w=w,
212 | images=background,
213 | return_grid=True,
214 | return_pos=True)
215 |
216 | opdict['rendered_images'] = shape_images
217 |
218 | if self.cfg.model.use_tex:
219 | opdict['albedo'] = albedo
220 |
221 | if vis_lmk:
222 | landmarks3d_vis = self.visofp(ops['transformed_normals']) # /self.image_size
223 | landmarks3d = torch.cat([landmarks3d, landmarks3d_vis], dim=2)
224 | opdict['landmarks3d'] = landmarks3d
225 |
226 | if return_vis:
227 | ## render shape
228 | shape_images, _, grid, alpha_images, pos_mask = self.render.render_shape(verts, trans_verts, h=h, w=w,
229 | images=background, return_grid=True, return_pos=True)
230 |
231 | # opdict['uv_texture_gt'] = uv_texture_gt
232 | visdict = {
233 | # 'inputs': images,
234 | 'landmarks2d': util.tensor_vis_landmarks(images, landmarks2d),
235 | 'landmarks3d': util.tensor_vis_landmarks(images, landmarks3d),
236 | 'shape_images': shape_images,
237 | # 'rendered_images': ops['images']
238 | }
239 |
240 | return opdict, visdict
241 |
242 | else:
243 | return opdict
244 |
245 | def train(self):
246 | self.E_expression.train()
247 |
248 | self.E_flame.eval()
249 |
250 |
251 | def eval(self):
252 | self.E_expression.eval()
253 | self.E_flame.eval()
254 |
255 |
256 | def model_dict(self):
257 | return {
258 | 'E_flame': self.E_flame.state_dict(),
259 | 'E_expression': self.E_expression.state_dict(),
260 | }
261 |
--------------------------------------------------------------------------------
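A runnable sketch of what `SPECTRE.decompose_code` above does with the flattened encoder output. The per-group sizes are the usual DECA-style config values (they also account for the 236-dim vectors handled by the rendering tool earlier: 100+50+50+6+3+27 = 236); they are assumptions here and are read from the config at runtime in the real code.

import torch

num_dict = {'shape': 100, 'tex': 50, 'exp': 50, 'pose': 6, 'cam': 3, 'light': 27}
code = torch.randn(4, sum(num_dict.values()))        # stand-in for an E_flame output, batch of 4

code_dict, start = {}, 0
for key, n in num_dict.items():
    code_dict[key] = code[:, start:start + n]
    start += n
code_dict['light'] = code_dict['light'].reshape(-1, 9, 3)   # reshaped to [B, 9, 3] as in decompose_code

print({k: tuple(v.shape) for k, v in code_dict.items()})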
/external/spectre/src/utils/lossfunc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def l2_distance(verts1, verts2):
8 | return torch.sqrt(((verts1 - verts2)**2).sum(2)).mean(1).mean()
9 |
10 | ### ------------------------------------- Losses/Regularizations for vertices
11 | def batch_kp_2d_l1_loss(real_2d_kp, predicted_2d_kp, weights=None):
12 | """
13 | Computes the l1 loss between the ground truth keypoints and the predicted keypoints
14 | Inputs:
15 | kp_gt : N x K x 3
16 | kp_pred: N x K x 2
17 | """
18 | if weights is not None:
19 | real_2d_kp[:,:,2] = weights[None,:]*real_2d_kp[:,:,2]
20 | kp_gt = real_2d_kp.view(-1, 3)
21 | kp_pred = predicted_2d_kp.contiguous().view(-1, 2)
22 | vis = kp_gt[:, 2]
23 | k = torch.sum(vis) * 2.0 + 1e-8
24 |
25 | dif_abs = torch.abs(kp_gt[:, :2] - kp_pred).sum(1)
26 |
27 | return torch.matmul(dif_abs, vis) * 1.0 / k
28 |
29 | def landmark_loss(predicted_landmarks, landmarks_gt, weight=1.):
30 | if torch.is_tensor(landmarks_gt) is not True:
31 | real_2d = torch.cat(landmarks_gt).cuda()
32 | else:
33 | real_2d = torch.cat([landmarks_gt, torch.ones((landmarks_gt.shape[0], 68, 1)).cuda()], dim=-1)
34 |
35 | loss_lmk_2d = batch_kp_2d_l1_loss(real_2d, predicted_landmarks)
36 | return loss_lmk_2d * weight
37 |
38 |
39 | def weighted_landmark_loss(predicted_landmarks, landmarks_gt, weight=1.):
40 | #smaller inner landmark weights
41 | # (predicted_theta, predicted_verts, predicted_landmarks) = ringnet_outputs[-1]
42 | # import ipdb; ipdb.set_trace()
43 | real_2d = landmarks_gt
44 | weights = torch.ones((68,)).cuda()
45 | weights[5:7] = 2
46 | weights[10:12] = 2
47 | # nose points
48 | weights[27:36] = 1.5
49 | weights[30] = 3
50 | weights[31] = 3
51 | weights[35] = 3
52 |
53 | # set mouth to zero
54 | weights[60:68] = 0
55 | weights[48:60] = 0
56 | weights[48] = 0
57 | weights[54] = 0
58 |
59 |
60 | # weights[36:48] = 0 # these are eyes
61 |
62 | loss_lmk_2d = batch_kp_2d_l1_loss(real_2d, predicted_landmarks, weights)
63 | return loss_lmk_2d * weight
64 |
65 |
66 | def rel_dis(landmarks):
67 |
68 | lip_right = landmarks[:, [57, 51, 48, 60, 61, 62, 63], :]
69 | lip_left = landmarks[:, [8, 33, 54, 64, 67, 66, 65], :]
70 |
71 | # lip_right = landmarks[:, [61, 62, 63], :]
72 | # lip_left = landmarks[:, [67, 66, 65], :]
73 |
74 | dis = torch.sqrt(((lip_right - lip_left) ** 2).sum(2)) # [bz, 4]
75 |
76 | return dis
77 |
78 | def relative_landmark_loss(predicted_landmarks, landmarks_gt, weight=1.):
79 | if torch.is_tensor(landmarks_gt) is not True:
80 | real_2d = torch.cat(landmarks_gt)#.cuda()
81 | else:
82 | real_2d = torch.cat([landmarks_gt, torch.ones((landmarks_gt.shape[0], 68, 1)).to(device=predicted_landmarks.device) #.cuda()
83 | ], dim=-1)
84 | pred_lipd = rel_dis(predicted_landmarks[:, :, :2])
85 | gt_lipd = rel_dis(real_2d[:, :, :2])
86 |
87 | loss = (pred_lipd - gt_lipd).abs().mean()
88 | # loss = F.mse_loss(pred_lipd, gt_lipd)
89 |
90 | return loss.mean()
91 |
92 |
--------------------------------------------------------------------------------
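A standalone restatement of the relative landmark loss above (`rel_dis` plus `relative_landmark_loss`): the indexed pairs connect mouth landmarks to reference points (chin, nose base, opposite lip and mouth corner in the 68-point iBUG layout), so the loss penalizes errors in relative mouth opening and shape rather than absolute landmark position. The inputs below are random and purely illustrative.

import torch

def lip_distances(landmarks):
    # same index pairs as rel_dis above
    a = landmarks[:, [57, 51, 48, 60, 61, 62, 63], :]
    b = landmarks[:, [8, 33, 54, 64, 67, 66, 65], :]
    return torch.sqrt(((a - b) ** 2).sum(2))            # [B, 7] pairwise distances

pred = torch.rand(4, 68, 2)    # predicted 2D landmarks
gt = torch.rand(4, 68, 2)      # ground-truth 2D landmarks
loss = (lip_distances(pred) - lip_distances(gt)).abs().mean()
print(loss.item())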
/external/spectre/src/utils/rotation_converter.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ''' Rotation Converter
4 | Repre: euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4)
5 | ref: https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html#
6 | "pi",
7 | "rad2deg",
8 | "deg2rad",
9 | # "angle_axis_to_rotation_matrix", batch_rodrigues
10 | "rotation_matrix_to_angle_axis",
11 | "rotation_matrix_to_quaternion",
12 | "quaternion_to_angle_axis",
13 | # "angle_axis_to_quaternion",
14 |
15 | euler2quat_conversion_sanity_batch
16 |
17 | ref: smplx/lbs
18 | batch_rodrigues: axis angle -> matrix
19 | #
20 | '''
21 | pi = torch.Tensor([3.14159265358979323846])
22 |
23 | def rad2deg(tensor):
24 | """Function that converts angles from radians to degrees.
25 |
26 | See :class:`~torchgeometry.RadToDeg` for details.
27 |
28 | Args:
29 | tensor (Tensor): Tensor of arbitrary shape.
30 |
31 | Returns:
32 | Tensor: Tensor with same shape as input.
33 |
34 | Example:
35 | >>> input = tgm.pi * torch.rand(1, 3, 3)
36 | >>> output = tgm.rad2deg(input)
37 | """
38 | if not torch.is_tensor(tensor):
39 | raise TypeError("Input type is not a torch.Tensor. Got {}"
40 | .format(type(tensor)))
41 |
42 | return 180. * tensor / pi.to(tensor.device).type(tensor.dtype)
43 |
44 | def deg2rad(tensor):
45 | """Function that converts angles from degrees to radians.
46 |
47 | See :class:`~torchgeometry.DegToRad` for details.
48 |
49 | Args:
50 | tensor (Tensor): Tensor of arbitrary shape.
51 |
52 | Returns:
53 | Tensor: Tensor with same shape as input.
54 |
55 | Examples::
56 |
57 | >>> input = 360. * torch.rand(1, 3, 3)
58 | >>> output = tgm.deg2rad(input)
59 | """
60 | if not torch.is_tensor(tensor):
61 | raise TypeError("Input type is not a torch.Tensor. Got {}"
62 | .format(type(tensor)))
63 |
64 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.
65 |
66 | ######### to quaternion
67 | def euler_to_quaternion(r):
68 | x = r[..., 0]
69 | y = r[..., 1]
70 | z = r[..., 2]
71 |
72 | z = z/2.0
73 | y = y/2.0
74 | x = x/2.0
75 | cz = torch.cos(z)
76 | sz = torch.sin(z)
77 | cy = torch.cos(y)
78 | sy = torch.sin(y)
79 | cx = torch.cos(x)
80 | sx = torch.sin(x)
81 | quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device)
82 | quaternion[..., 0] += cx*cy*cz - sx*sy*sz
83 | quaternion[..., 1] += cx*sy*sz + cy*cz*sx
84 | quaternion[..., 2] += cx*cz*sy - sx*cy*sz
85 | quaternion[..., 3] += cx*cy*sz + sx*cz*sy
86 | return quaternion
87 |
88 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
89 | """Convert 3x4 rotation matrix to 4d quaternion vector
90 |
91 | This algorithm is based on algorithm described in
92 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
93 |
94 | Args:
95 | rotation_matrix (Tensor): the rotation matrix to convert.
96 |
97 | Return:
98 | Tensor: the rotation in quaternion
99 |
100 | Shape:
101 | - Input: :math:`(N, 3, 4)`
102 | - Output: :math:`(N, 4)`
103 |
104 | Example:
105 | >>> input = torch.rand(4, 3, 4) # Nx3x4
106 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
107 | """
108 | if not torch.is_tensor(rotation_matrix):
109 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
110 | type(rotation_matrix)))
111 |
112 | if len(rotation_matrix.shape) > 3:
113 | raise ValueError(
114 | "Input size must be a three dimensional tensor. Got {}".format(
115 | rotation_matrix.shape))
116 | # if not rotation_matrix.shape[-2:] == (3, 4):
117 | # raise ValueError(
118 | # "Input size must be a N x 3 x 4 tensor. Got {}".format(
119 | # rotation_matrix.shape))
120 |
121 | rmat_t = torch.transpose(rotation_matrix, 1, 2)
122 |
123 | mask_d2 = rmat_t[:, 2, 2] < eps
124 |
125 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
126 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
127 |
128 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
129 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
130 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
131 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
132 | t0_rep = t0.repeat(4, 1).t()
133 |
134 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
135 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
136 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
137 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
138 | t1_rep = t1.repeat(4, 1).t()
139 |
140 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
141 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
142 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
143 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
144 | t2_rep = t2.repeat(4, 1).t()
145 |
146 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
147 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
148 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
149 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
150 | t3_rep = t3.repeat(4, 1).t()
151 |
152 | mask_c0 = mask_d2 * mask_d0_d1.float()
153 | mask_c1 = mask_d2 * (1 - mask_d0_d1.float())
154 | mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1
155 | mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float())
156 | mask_c0 = mask_c0.view(-1, 1).type_as(q0)
157 | mask_c1 = mask_c1.view(-1, 1).type_as(q1)
158 | mask_c2 = mask_c2.view(-1, 1).type_as(q2)
159 | mask_c3 = mask_c3.view(-1, 1).type_as(q3)
160 |
161 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
162 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
163 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
164 | q *= 0.5
165 | return q
166 |
167 | # def angle_axis_to_quaternion(theta):
168 | # batch_size = theta.shape[0]
169 | # l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
170 | # angle = torch.unsqueeze(l1norm, -1)
171 | # normalized = torch.div(theta, angle)
172 | # angle = angle * 0.5
173 | # v_cos = torch.cos(angle)
174 | # v_sin = torch.sin(angle)
175 | # quat = torch.cat([v_cos, v_sin * normalized], dim=1)
176 | # return quat
177 |
178 | def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor:
179 | """Convert an angle axis to a quaternion.
180 |
181 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
182 |
183 | Args:
184 | angle_axis (torch.Tensor): tensor with angle axis.
185 |
186 | Return:
187 | torch.Tensor: tensor with quaternion.
188 |
189 | Shape:
190 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions
191 | - Output: :math:`(*, 4)`
192 |
193 | Example:
194 | >>> angle_axis = torch.rand(2, 4) # Nx4
195 | >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3
196 | """
197 | if not torch.is_tensor(angle_axis):
198 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
199 | type(angle_axis)))
200 |
201 | if not angle_axis.shape[-1] == 3:
202 | raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}"
203 | .format(angle_axis.shape))
204 | # unpack input and compute conversion
205 | a0: torch.Tensor = angle_axis[..., 0:1]
206 | a1: torch.Tensor = angle_axis[..., 1:2]
207 | a2: torch.Tensor = angle_axis[..., 2:3]
208 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2
209 |
210 | theta: torch.Tensor = torch.sqrt(theta_squared)
211 | half_theta: torch.Tensor = theta * 0.5
212 |
213 | mask: torch.Tensor = theta_squared > 0.0
214 | ones: torch.Tensor = torch.ones_like(half_theta)
215 |
216 | k_neg: torch.Tensor = 0.5 * ones
217 | k_pos: torch.Tensor = torch.sin(half_theta) / theta
218 | k: torch.Tensor = torch.where(mask, k_pos, k_neg)
219 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones)
220 |
221 | quaternion: torch.Tensor = torch.zeros_like(angle_axis)
222 | quaternion[..., 0:1] += a0 * k
223 | quaternion[..., 1:2] += a1 * k
224 | quaternion[..., 2:3] += a2 * k
225 | return torch.cat([w, quaternion], dim=-1)
226 |
227 | #### quaternion to
228 | def quaternion_to_rotation_matrix(quat):
229 | """Convert quaternion coefficients to rotation matrix.
230 | Args:
231 | quat: size = [B, 4] 4 <===>(w, x, y, z)
232 | Returns:
233 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
234 | """
235 | norm_quat = quat
236 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
237 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
238 |
239 | B = quat.size(0)
240 |
241 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
242 | wx, wy, wz = w * x, w * y, w * z
243 | xy, xz, yz = x * y, x * z, y * z
244 |
245 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
246 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
247 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
248 | return rotMat
249 |
250 | def quaternion_to_angle_axis(quaternion: torch.Tensor):
251 | """Convert quaternion vector to angle axis of rotation. TODO: CORRECT
252 |
253 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
254 |
255 | Args:
256 | quaternion (torch.Tensor): tensor with quaternions.
257 |
258 | Return:
259 | torch.Tensor: tensor with angle axis of rotation.
260 |
261 | Shape:
262 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions
263 | - Output: :math:`(*, 3)`
264 |
265 | Example:
266 | >>> quaternion = torch.rand(2, 4) # Nx4
267 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
268 | """
269 | if not torch.is_tensor(quaternion):
270 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
271 | type(quaternion)))
272 |
273 | if not quaternion.shape[-1] == 4:
274 | raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
275 | .format(quaternion.shape))
276 | # unpack input and compute conversion
277 | q1: torch.Tensor = quaternion[..., 1]
278 | q2: torch.Tensor = quaternion[..., 2]
279 | q3: torch.Tensor = quaternion[..., 3]
280 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
281 |
282 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
283 | cos_theta: torch.Tensor = quaternion[..., 0]
284 | two_theta: torch.Tensor = 2.0 * torch.where(
285 | cos_theta < 0.0,
286 | torch.atan2(-sin_theta, -cos_theta),
287 | torch.atan2(sin_theta, cos_theta))
288 |
289 | k_pos: torch.Tensor = two_theta / sin_theta
290 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device)
291 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
292 |
293 | angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3]
294 | angle_axis[..., 0] += q1 * k
295 | angle_axis[..., 1] += q2 * k
296 | angle_axis[..., 2] += q3 * k
297 | return angle_axis
298 |
299 | #### batch converter
300 | def batch_euler2axis(r):
301 | return quaternion_to_angle_axis(euler_to_quaternion(r))
302 |
303 | def batch_euler2matrix(r):
304 | return quaternion_to_rotation_matrix(euler_to_quaternion(r))
305 |
306 | def batch_matrix2euler(rot_mats):
307 | # Calculates rotation matrix to euler angles
308 |     # Careful for extreme cases of euler angles like [0.0, pi, 0.0]
309 | ### only y?
310 | # TODO:
311 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
312 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
313 | return torch.atan2(-rot_mats[:, 2, 0], sy)
314 |
315 | def batch_matrix2axis(rot_mats):
316 | return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats))
317 |
318 | def batch_axis2matrix(theta):
319 | # angle axis to rotation matrix
320 | # theta N x 3
321 | # return quat2mat(quat)
322 | # batch_rodrigues
323 | return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta))
324 |
325 | def batch_axis2euler(theta):
326 | return batch_matrix2euler(batch_axis2matrix(theta))
327 |
328 | # def batch_axis2euler(r):  # duplicate definition shadowing the one above; rot_mat_to_euler is not defined in this module
329 | #     return rot_mat_to_euler(batch_rodrigues(r))
330 |
331 |
332 | def batch_orth_proj(X, camera):
333 | '''
334 |     X is N x num_points x 3
335 | '''
336 | camera = camera.clone().view(-1, 1, 3)
337 | X_trans = X[:, :, :2] + camera[:, :, 1:]
338 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2)
339 | Xn = (camera[:, :, 0:1] * X_trans)
340 | return Xn
341 |
342 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
343 | ''' same as batch_matrix2axis
344 | Calculates the rotation matrices for a batch of rotation vectors
345 | Parameters
346 | ----------
347 | rot_vecs: torch.tensor Nx3
348 | array of N axis-angle vectors
349 | Returns
350 | -------
351 | R: torch.tensor Nx3x3
352 | The rotation matrices for the given axis-angle parameters
353 | '''
354 |
355 | batch_size = rot_vecs.shape[0]
356 | device = rot_vecs.device
357 |
358 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
359 | rot_dir = rot_vecs / angle
360 |
361 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
362 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
363 |
364 | # Bx1 arrays
365 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
366 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
367 |
368 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
369 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
370 | .view((batch_size, 3, 3))
371 |
372 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
373 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
374 | return rot_mat
375 |
--------------------------------------------------------------------------------
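A minimal numeric sketch of `batch_orth_proj` above: a weak-perspective projection X' = s * (X_xy + t), with the (scaled) z kept for depth ordering. The camera values mirror the constant used by the rendering tool earlier; the vertex count (FLAME's 5023) is illustrative only.

import torch

X = torch.randn(2, 5023, 3)                              # e.g. FLAME vertices, batch of 2
camera = torch.tensor([8.74, 0.0, 0.02]).repeat(2, 1)    # [scale, tx, ty]

cam = camera.view(-1, 1, 3)
X_trans = torch.cat([X[:, :, :2] + cam[:, :, 1:], X[:, :, 2:]], dim=2)
Xn = cam[:, :, 0:1] * X_trans                            # scaled, translated vertices
print(Xn.shape)                                          # torch.Size([2, 5023, 3])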
/external/spectre/src/utils/tensor_cropper.py:
--------------------------------------------------------------------------------
1 | '''
2 | crop
3 | for torch tensor
4 | Given image, bbox(center, bboxsize)
5 | return: cropped image, tform(used for transform the keypoint accordingly)
6 | only support crop to squared images
7 | '''
8 | import torch
9 | from kornia.geometry.transform.imgwarp import (
10 | warp_perspective, get_perspective_transform, warp_affine
11 | )
12 |
13 | def points2bbox(points, points_scale=None):
14 | if points_scale:
15 | assert points_scale[0]==points_scale[1]
16 | points = points.clone()
17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0]
18 | min_coords, _ = torch.min(points, dim=1)
19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1]
20 | max_coords, _ = torch.max(points, dim=1)
21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1]
22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
23 |
24 | width = (xmax - xmin)
25 | height = (ymax - ymin)
26 | # Convert the bounding box to a square box
27 | size = torch.max(width, height).unsqueeze(-1)
28 | return center, size
29 |
30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.):
31 | batch_size = center.shape[0]
32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale
33 | center = center + trans_scale*bbox_size # 0.5
34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0]
35 | size = bbox_size*scale
36 | return center, size
37 |
38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False):
39 | ''' for batch image
40 | Args:
41 |         image (torch.Tensor): the reference tensor of shape BxCxHxW.
42 |         center: [bz, 2]
43 |         bboxsize: [bz, 1]
44 |         crop_size: side length of the square output crop;
45 | interpolation (str): Interpolation flag. Default: 'bilinear'.
46 | align_corners (bool): mode for grid_generation. Default: False. See
47 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details
48 | Returns:
49 | cropped_image
50 | tform
51 | '''
52 | dtype = image.dtype
53 | device = image.device
54 | batch_size = image.shape[0]
55 | # points: top-left, top-right, bottom-right, bottom-left
56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous()
57 |
58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1)
59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
61 | src_pts[:, 2, :] = center + bbox_size * 0.5
62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5
64 |
65 | DST_PTS = torch.tensor([[
66 | [0, 0],
67 | [crop_size - 1, 0],
68 | [crop_size - 1, crop_size - 1],
69 | [0, crop_size - 1],
70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1)
71 | # estimate transformation between points
72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
73 | # simulate broadcasting
74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1)
75 |
76 | # warp images
77 | cropped_image = warp_affine(
78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size),
79 | flags=interpolation, align_corners=align_corners)
80 |
81 | tform = torch.transpose(dst_trans_src, 2, 1)
82 | # tform = torch.inverse(dst_trans_src)
83 | return cropped_image, tform
84 |
85 | class Cropper(object):
86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.):
87 | self.crop_size = crop_size
88 | self.scale = scale
89 | self.trans_scale = trans_scale
90 |
91 | def crop(self, image, points, points_scale=None):
92 | # points to bbox
93 | center, bbox_size = points2bbox(points.clone(), points_scale)
94 |         # augment bbox. TODO: add rotation?
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale)
96 | # crop
97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size)
98 | return cropped_image, tform
99 |
100 | def transform_points(self, points, tform, points_scale=None, normalize = True):
101 | points_2d = points[:,:,:2]
102 |
103 | #'input points must use original range'
104 | if points_scale:
105 | assert points_scale[0]==points_scale[1]
106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
107 |
108 | batch_size, n_points, _ = points.shape
109 | trans_points_2d = torch.bmm(
110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
111 | tform
112 | )
113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
114 | if normalize:
115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1
116 | return trans_points
117 |
118 | def transform_points(points, tform, points_scale=None, out_scale=None):
119 | points_2d = points[:,:,:2]
120 |
121 | #'input points must use original range'
122 | if points_scale:
123 | assert points_scale[0]==points_scale[1]
124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0]
125 | # import ipdb; ipdb.set_trace()
126 |
127 | batch_size, n_points, _ = points.shape
128 | trans_points_2d = torch.bmm(
129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
130 | tform
131 | )
132 | if out_scale: # h,w of output image size
133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1
134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1
135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1)
136 | return trans_points
--------------------------------------------------------------------------------
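For orientation, a minimal usage sketch of the cropper above (not part of the repository). The batch size, the 224x224 input and 112x112 crop sizes, and the 68-landmark shape are illustrative assumptions, and the import path may need adjusting to your layout:

    import torch
    from external.spectre.src.utils.tensor_cropper import Cropper, transform_points

    images = torch.rand(2, 3, 224, 224)       # B x C x H x W, values in [0, 1]
    landmarks = torch.rand(2, 68, 2) * 2 - 1  # 68 points per image, normalized to [-1, 1]

    # square crop around the landmarks, with random scale / translation jitter
    cropper = Cropper(crop_size=112, scale=[1.2, 1.4], trans_scale=0.1)
    cropped, tform = cropper.crop(images, landmarks, points_scale=[224, 224])
    # cropped: [2, 3, 112, 112], tform: [2, 3, 3]

    # map the original landmarks into the crop, normalized back to [-1, 1]
    lmk_in_crop = transform_points(landmarks, tform, points_scale=[224, 224], out_scale=[112, 112])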
/external/spectre/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/utils/__init__.py
--------------------------------------------------------------------------------
/external/spectre/utils/extract_frames_LRS3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from argparse import ArgumentParser
4 | from multiprocessing import Pool
5 | from tqdm import tqdm
6 | 
7 | 
8 | def extract(video_path):
9 |     """Extract frames with ffmpeg into a folder named after the video."""
10 |     out_dir = video_path.replace(".mp4", "")
11 |     os.makedirs(out_dir, exist_ok=True)
12 |     cmd = 'ffmpeg -i "{}" -threads 1 -q:v 0 "{}/%06d.jpg"'.format(video_path, out_dir)
13 |     os.system(cmd)
14 | 
15 | 
16 | def video2sequence(video_path):
17 |     """Alternative frame extraction with OpenCV; returns the list of saved frame paths."""
18 |     videofolder = os.path.splitext(video_path)[0]
19 |     os.makedirs(videofolder, exist_ok=True)
20 |     video_name = os.path.splitext(os.path.split(video_path)[-1])[0]
21 |     vidcap = cv2.VideoCapture(video_path)
22 |     success, image = vidcap.read()
23 |     count = 0
24 |     imagepath_list = []
25 |     while success:
26 |         imagepath = os.path.join(videofolder, f'{video_name}_frame{count:05d}.jpg')
27 |         cv2.imwrite(imagepath, image)  # save frame as JPEG file
28 |         success, image = vidcap.read()
29 |         count += 1
30 |         imagepath_list.append(imagepath)
31 |     print('video frames are stored in {}'.format(videofolder))
32 |     return imagepath_list
33 | 
34 | 
35 | def main():
36 |     # Parse command-line arguments
37 |     parser = ArgumentParser()
38 |     parser.add_argument('--root', default="/gpu-data3/filby/LRS3/pretrain", type=str,
39 |                         help='path to the LRS3 pretrain split')
40 |     args = parser.parse_args()
41 | 
42 |     # collect every video that has a matching transcript (.txt) file
43 |     test_list = []
44 |     for folder in os.listdir(args.root):
45 |         for file in os.listdir(os.path.join(args.root, folder)):
46 |             if file.endswith(".txt"):
47 |                 test_list.append(os.path.join(args.root, folder, file.replace(".txt", ".mp4")))
48 | 
49 |     # extract frames in parallel
50 |     p = Pool(12)
51 |     for _ in tqdm(p.imap_unordered(extract, test_list), total=len(test_list)):
52 |         pass
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     main()
--------------------------------------------------------------------------------
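This script, extract_frames_and_audio.py, and extract_wavs_LRS3.py all rely on the same parallelization idiom: a multiprocessing Pool, imap_unordered to stream results as workers finish, and tqdm for progress. A self-contained sketch of the pattern (the work function is a placeholder, not part of the repository):

    from multiprocessing import Pool
    from tqdm import tqdm

    def work(item):
        # stand-in for per-video work such as extract() above
        return item * item

    if __name__ == '__main__':
        items = list(range(100))
        with Pool(4) as pool:
            # results arrive in completion order; only the side effects matter here
            for _ in tqdm(pool.imap_unordered(work, items), total=len(items)):
                pass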
/external/spectre/utils/extract_frames_and_audio.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os, sys
4 | import cv2
5 | import argparse
6 | from tqdm import tqdm
7 | from multiprocessing import Pool
8 |
9 | def video2sequence(video_path):
10 | videofolder = os.path.splitext(video_path)[0]
11 | os.makedirs(videofolder, exist_ok=True)
12 | vidcap = cv2.VideoCapture(video_path)
13 | success,image = vidcap.read()
14 | count = 0
15 | imagepath_list = []
16 | while success:
17 |         imagepath = os.path.join(videofolder, '%06d.jpg' % count)
18 | cv2.imwrite(imagepath, image) # save frame as JPEG file
19 | success,image = vidcap.read()
20 | count += 1
21 | imagepath_list.append(imagepath)
22 | print('video frames are stored in {}'.format(videofolder))
23 | return videofolder
24 |
25 |
26 | def extract_audio(video_path):
27 | os.system("ffmpeg -i {} {} -y".format(video_path, video_path.replace(".mp4",".wav")))
28 |
29 |
30 | def main(args):
31 | video_list = []
32 |
33 | for mode in ["trainval","test"]:
34 | for folder in os.listdir(os.path.join(args.dataset_path,mode)):
35 | for file in os.listdir(os.path.join(args.dataset_path,mode,folder)):
36 | if file.endswith(".mp4"):
37 | video_list.append(os.path.join(args.dataset_path,mode,folder,file))
38 |
39 | p = Pool(12)
40 |
41 | for _ in tqdm(p.imap_unordered(video2sequence, video_list), total=len(video_list)):
42 | pass
43 |
44 | for _ in tqdm(p.imap_unordered(extract_audio, video_list), total=len(video_list)):
45 | pass
46 |
47 |
48 | if __name__ == '__main__':
49 | parser = argparse.ArgumentParser()
50 |
51 | parser.add_argument('--dataset_path', default='./data/LRS3', type=str, help='path to dataset')
52 | main(parser.parse_args())
--------------------------------------------------------------------------------
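The script is meant to be run over a whole dataset, e.g. "python extract_frames_and_audio.py --dataset_path ./data/LRS3", and expects trainval/ and test/ subfolders of per-speaker directories containing .mp4 files. For a single clip the two helpers can also be called directly; a hypothetical sketch (the path is illustrative and the import may need adjusting to your layout):

    from extract_frames_and_audio import video2sequence, extract_audio

    video = "./data/LRS3/trainval/spk001/00001.mp4"   # illustrative path
    frame_dir = video2sequence(video)  # writes 000000.jpg, 000001.jpg, ... into .../00001/
    extract_audio(video)               # writes .../00001.wav next to the video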
/external/spectre/utils/extract_wavs_LRS3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import numpy as np
5 | import torch
6 | from argparse import ArgumentParser
7 |
8 | import sys
9 | sys.path.append("face_parsing")
10 |
11 |
12 | def extract_wav(videopath):
13 | print(videopath)
14 |
15 | os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace(".mp4",".wav")))
16 |
17 | from multiprocessing import Pool
18 | from tqdm import tqdm
19 |
20 | def main():
21 | # Parse command-line arguments
22 | parser = ArgumentParser()
23 |
24 |     root = "/raid/gretsinas/LRS3/test"
25 | 
26 |     p = Pool(12)
27 | 
28 |     # collect every video that has a matching transcript (.txt) file
29 |     test_list = []
30 |     for folder in os.listdir(root):
31 |         for file in os.listdir(os.path.join(root, folder)):
32 | 
33 |             if file.endswith(".txt"):
34 |                 test_list.append(os.path.join(root, folder, file.replace(".txt", ".mp4")))
35 |
36 | # print(test_list)
37 | # extract_wav(test_list[0])
38 | for _ in tqdm(p.imap_unordered(extract_wav, test_list), total=len(test_list)):
39 | pass
40 |
41 |
42 | if __name__ == '__main__':
43 |     main()
44 | # import os
45 | # import cv2
46 | # import time
47 | # import numpy as np
48 | # import torch
49 | # from argparse import ArgumentParser
50 | #
51 | # import sys
52 | # sys.path.append("face_parsing")
53 | #
54 | #
55 | # def extract_wav(videopath):
56 | # # print(videopath)
57 | #
58 | # os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace("/videos/","/wavs/").replace(".mp4",".wav")))
59 | #
60 | # from multiprocessing import Pool
61 | # from tqdm import tqdm
62 | #
63 | # def main():
64 | # # Parse command-line arguments
65 | # parser = ArgumentParser()
66 | #
67 | # root = "/gpu-data3/filby/MEAD/rendered/train/MEAD/videos"
68 | #
69 | # p = Pool(20)
70 | #
71 | # test_list = []
72 | # for file in os.listdir(root):
73 | # test_list.append(os.path.join(root,file))
74 | #
75 | # # print(test_list)
76 | # # extract_wav(test_list[0])
77 | # for _ in tqdm(p.imap_unordered(extract_wav, test_list), total=len(test_list)):
78 | # pass
79 | #
80 | #
81 | # main()
--------------------------------------------------------------------------------
/external/spectre/utils/lipread_utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import torch
6 | from phonemizer.backend import EspeakBackend
7 | from phonemizer.separator import Separator
8 | separator = Separator(phone='-', word=' ')
9 | backend = EspeakBackend('en-us', words_mismatch='ignore', with_stress=False)
10 | import cv2
11 |
12 | # phonemes to visemes map. this was created using Amazon Polly
13 | # https://docs.aws.amazon.com/polly/latest/dg/polly-dg.pdf
14 |
15 | def get_phoneme_to_viseme_map():
16 | pho2vi = {}
17 | # pho2vi_counts = {}
18 | all_vis = []
19 |
20 | p2v = "data/phonemes2visemes.csv"
21 |
22 | with open(p2v) as file:
23 | lines = file.readlines()
24 | # for line in lines[2:29]+lines[30:50]:
25 | for line in lines:
26 | if line.split(",")[0] in pho2vi:
27 | if line.split(",")[4].strip() != pho2vi[line.split(",")[0]]:
28 |                 print('error: conflicting viseme for phoneme', line.split(",")[0])
29 | pho2vi[line.split(",")[0]] = line.split(",")[4].strip()
30 |
31 | all_vis.append(line.split(",")[4].strip())
32 | # pho2vi_counts[line.split(",")[0]] = 0
33 | return pho2vi, all_vis
34 |
35 | pho2vi, all_vis = get_phoneme_to_viseme_map()
36 |
37 | def convert_text_to_visemes(text):
38 | phonemized = backend.phonemize([text], separator=separator)[0]
39 |
40 | text = ""
41 | for word in phonemized.split(" "):
42 | visemized = []
43 | for phoneme in word.split("-"):
44 | if phoneme == "":
45 | continue
46 | try:
47 | visemized.append(pho2vi[phoneme.strip()])
48 | if pho2vi[phoneme.strip()] not in all_vis:
49 | all_vis.append(pho2vi[phoneme.strip()])
50 | # pho2vi_counts[phoneme.strip()] += 1
51 |             except KeyError:
52 |                 print('Could not find', phoneme)
53 | continue
54 | text += " " + "".join(visemized)
55 | return text
56 |
57 |
58 |
59 | def save2avi(filename, data=None, fps=25):
60 | """save2avi. - function taken from Visual Speech Recognition repository
61 |
62 | :param filename: str, the filename to save the video (.avi).
63 | :param data: numpy.ndarray, the data to be saved.
64 | :param fps: the chosen frames per second.
65 | """
66 | assert data is not None, "data is {}".format(data)
67 | os.makedirs(os.path.dirname(filename), exist_ok=True)
68 | fourcc = cv2.VideoWriter_fourcc("F", "F", "V", "1")
69 | writer = cv2.VideoWriter(filename, fourcc, fps, (data[0].shape[1], data[0].shape[0]), 0)
70 | for frame in data:
71 | writer.write(frame)
72 | writer.release()
73 |
74 |
75 | def predict_text(lipreader, mouth_sequence):
76 | from external.Visual_Speech_Recognition_for_Multiple_Languages.espnet.asr.asr_utils import add_results_to_json
77 | lipreader.model.eval()
78 | with torch.no_grad():
79 | enc_feats, _ = lipreader.model.encoder(mouth_sequence, None)
80 | enc_feats = enc_feats.squeeze(0)
81 |
82 | nbest_hyps = lipreader.beam_search(
83 | x=enc_feats,
84 | maxlenratio=lipreader.maxlenratio,
85 | minlenratio=lipreader.minlenratio
86 | )
87 | nbest_hyps = [
88 | h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), lipreader.nbest)]
89 | ]
90 |
91 | transcription = add_results_to_json(nbest_hyps, lipreader.char_list)
92 |
93 | return transcription.replace("