├── .gitignore ├── LICENSE ├── README.md ├── base ├── __init__.py ├── baseTrainer.py ├── base_model.py ├── config.py └── utilities.py ├── comparisons.png ├── config └── HDTF │ └── config.yaml ├── data └── dataloader_HDTF.py ├── demo.py ├── external └── spectre │ ├── .gitignore │ ├── .gitmodules │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── configs │ └── lipread_config.ini │ ├── datasets │ ├── __init__.py │ ├── build_datasets.py │ ├── data_utils.py │ ├── datasets.py │ └── extra_datasets.py │ ├── demo.py │ ├── get_training_data.sh │ ├── main.py │ ├── quick_install.sh │ ├── render.py │ ├── requirements.txt │ ├── src │ ├── __init__.py │ ├── models │ │ ├── FLAME.py │ │ ├── encoders.py │ │ ├── expression_loss.py │ │ ├── lbs.py │ │ └── resnet.py │ ├── spectre.py │ ├── trainer_spectre.py │ └── utils │ │ ├── lossfunc.py │ │ ├── renderer.py │ │ ├── rotation_converter.py │ │ ├── tensor_cropper.py │ │ ├── trainer.py │ │ └── util.py │ ├── utils │ ├── __init__.py │ ├── extract_frames_LRS3.py │ ├── extract_frames_and_audio.py │ ├── extract_wavs_LRS3.py │ ├── lipread_utils.py │ └── run_av_hubert.py │ └── visual_mesh.py ├── framework.png ├── losses └── loss_collections.py ├── models ├── lib │ ├── base_models.py │ ├── grl_module.py │ ├── modules.py │ └── wav2vec.py └── network.py ├── requirements.txt ├── tools └── render_spectre.py ├── train.py └── utils └── render_pyrender.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python file # 2 | ############ 3 | *.pyc 4 | __pycache__ 5 | # OS generated files # 6 | ###################### 7 | .DS_Store 8 | .DS_Store? 9 | ._* 10 | .Spotlight-V100 11 | .Trashes 12 | ehthumbs.db 13 | Thumbs.db 14 | .vscode 15 | 16 | # Packages # 17 | ############ 18 | # it's better to unpack these files and commit the raw source 19 | # git has its own built in compression methods 20 | *.7z 21 | *.dmg 22 | *.gz 23 | *.iso 24 | *.jar 25 | *.rar 26 | *.tar 27 | *.zip 28 | *.pth 29 | 30 | demos 31 | pretrained 32 | external/spectre/pretrained 33 | external/spectre/data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DoubleXING 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## **Mimic** 2 | 3 | Official PyTorch implementation for the paper: 4 | 5 | > **Mimic: Speaking Style Disentanglement for Speech-Driven 3D Facial Animation**, ***AAAI 2024***. 6 | > 7 | > Hui Fu, Zeqing Wang, Ke Gong, Keze Wang, Tianshui Chen, Haojie Li, Haifeng Zeng, Wenxiong Kang 8 | > 9 | > 10 | 11 |


14 | 15 | >Speech-driven 3D facial animation aims to synthesize vivid facial animations that accurately synchronize with speech and match the unique speaking style. However, existing works primarily focus on achieving precise lip synchronization while neglecting to model the subject-specific speaking style, often resulting in unrealistic facial animations. To the best of our knowledge, this work makes the first attempt to explore the coupled information between the speaking style and the semantic content in facial motions. Specifically, we introduce an innovative speaking style disentanglement method, which enables arbitrary-subject speaking style encoding and leads to a more realistic synthesis of speech-driven facial animations. Subsequently, we propose a novel framework called **Mimic** to learn disentangled representations of the speaking style and content from facial motions by building two latent spaces for style and content, respectively. Moreover, to facilitate disentangled representation learning, we introduce four well-designed constraints: an auxiliary style classifier, an auxiliary inverse classifier, a content contrastive loss, and a pair of latent cycle losses, which can effectively contribute to the construction of the identity-related style space and semantic-related content space. Extensive qualitative and quantitative experiments conducted on three publicly available datasets demonstrate that our approach outperforms state-of-the-art methods and is capable of capturing diverse speaking styles for speech-driven 3D facial animation. 16 | 17 |


20 | 21 | ## **TODO** 22 | - ~~Release codes and weights for inference.~~ 23 | - ~~Release 3D-HDTF dataset.~~ 24 | - ~~Release codes for training.~~ 25 | 26 | ## **Environment** 27 | - Ubuntu 28 | - RTX 4090 29 | - CUDA 11.6 (GPU with at least 24GB VRAM) 30 | - Python 3.9 31 | ## **Dependencies** 32 | - PyTorch 1.13.1 33 | - ffmpeg 34 | - [PyTorch3D](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) (recommended) or [MPI-IS/mesh](https://github.com/MPI-IS/mesh) for rendering 35 | 36 | Other necessary packages: 37 | ``` 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ## **Demo** 42 | We provide some demos for 3D-HDTF. Please follow the steps below to run them. 43 | 44 | 1) Prepare data and pretrained models 45 | 46 | Clone the repository using: `git clone https://github.com/huifu99/Mimic.git`. 47 | 48 | Download the 3D-HDTF [data](https://drive.google.com/drive/folders/1s9FQQRAA_tqf58ThmWq3EWmFnV7syhhl?usp=drive_link) for demos and the [model](https://drive.google.com/drive/folders/122oSYMiwQyzg8kvfWhp6A-JejMZGCNM8?usp=drive_link) trained on 3D-HDTF. Then put them in the root directory of Mimic. 49 | 50 | Prepare the [SPECTRE model trained on HDTF](https://drive.google.com/drive/folders/18YM4J4u5Tpi-JLQLh-_UwDUwSPoVZBb1?usp=drive_link) and the [dependencies](https://drive.google.com/drive/folders/197z4B8GYZ9QFwzGFXBkIZtddHESlgg1K?usp=drive_link) of SPECTRE and put them in `external/spectre/`. 51 | Organize the files into the following structure: 52 | ``` 53 | Mimic 54 | │ 55 | └─── demos 56 | └─── wav 57 | └─── style_ref 58 | │ 59 | └─── pretrained 60 | └─── 61 | └─── Epoch_x.pth 62 | │ 63 | └─── external 64 | └───spectre 65 | └─── data 66 | └─── pretrained 67 | └─── HDTF_pretrained 68 | └─── ... 69 | │ 70 | └─── models 71 | │ 72 | └─── ... 73 | ``` 74 | 75 | 2) Run demos 76 | 77 | Run the following command to get the demo results (.npy file for vertices and .mp4 for videos) in `demos/results`: 78 | 79 | ``` 80 | python demo.py --wav_file demos/wav/RD_Radio11_001.wav --style_ref id_002-RD_Radio11_001 81 | ``` 82 | 83 | You can change parameters such as `--wav_file` and `--style_ref` according to your paths. The process for generating the style reference file will be provided soon. 84 | 85 | ## **Training and Evaluation** 86 | ### 3D-HDTF 87 | 1) Data Preparation 88 | 89 | Download our processed [3D-HDTF](https://pan.baidu.com/s/1ReX25BlG27mEcm4hKDUObw?pwd=HDTF) (extraction code: HDTF) data and put it in `./3D-HDTF` or your own directory. 90 | 91 | 2) Training 92 | 93 | - Modify your data path (or other settings) in the config file `./config/HDTF/config.yaml`. (optional) 94 | - Run the training script: `python train.py --config ./config/HDTF/config.yaml` 95 | 96 | 3) Evaluation 97 | 98 | ## **Acknowledgement** 99 | We heavily borrow the code from 100 | [CodeTalker](https://github.com/Doubiiu/CodeTalker), 101 | [VOCA](https://github.com/TimoBolkart/voca) and [SPECTRE](https://github.com/filby89/spectre). Thanks 102 | for sharing their code. Our 3D-HDTF dataset is based on [HDTF](https://github.com/MRzzm/HDTF). Third-party packages are owned by their respective authors and must be used under their respective licenses.
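A note on the demo output: `demo.py` saves the predicted mesh sequence as a NumPy dictionary. Below is a minimal loading sketch; the output directory, the file-name pattern, and the `verts_pred` key follow `demo.py`, and the concrete file name shown corresponds to the default demo command above.

```python
import numpy as np

# demo.py writes a dict via np.save; files are named
# "audio_{wav_name}-style_{style_ref}.npy" inside --output_path (default: demos/results).
out = np.load('demos/results/audio_RD_Radio11_001-style_id_002-RD_Radio11_001.npy',
              allow_pickle=True).item()

verts = out['verts_pred']   # (num_frames, 5023, 3) FLAME vertices, 25 fps with the provided HDTF config
print(verts.shape)
```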
103 | 104 | ## **Citation** 105 | 106 | If you find the code useful for your work, please star this repo and consider citing: 107 | 108 | ``` 109 | @inproceedings{hui2024Mimic, 110 | title={Mimic: Speaking Style Disentanglement for Speech-Driven 3D Facial Animation}, 111 | author={Hui Fu, Zeqing Wang, Ke Gong, Keze Wang, Tianshui Chen, Haojie Li, Haifeng Zeng, Wenxiong Kang}, 112 | booktitle={The 38th Annual AAAI Conference on Artificial Intelligence (AAAI)}, 113 | year={2024} 114 | } 115 | ``` -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import * -------------------------------------------------------------------------------- /base/baseTrainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | from os.path import join 4 | import torch.distributed as dist 5 | from .utilities import check_makedirs 6 | from collections import OrderedDict 7 | from torch.nn.parallel import DataParallel, DistributedDataParallel 8 | 9 | 10 | def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1): 11 | lr = base_lr * (multiplier ** (epoch // step_epoch)) 12 | return lr 13 | 14 | 15 | def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9): 16 | """poly learning rate policy""" 17 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power 18 | return lr 19 | 20 | 21 | def adjust_learning_rate(optimizer, lr): 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | 26 | def save_checkpoint(model, other_state={}, sav_path='', filename='model.pth.tar', stage=1): 27 | if isinstance(model, (DistributedDataParallel, DataParallel)): 28 | weight = model.module.state_dict() 29 | elif isinstance(model, torch.nn.Module): 30 | weight = model.state_dict() 31 | else: 32 | raise ValueError('model must be nn.Module or nn.DataParallel!') 33 | check_makedirs(sav_path) 34 | 35 | if stage == 2: # remove vqvae part 36 | for k in list(weight.keys()): 37 | if 'autoencoder' in k: 38 | weight.pop(k) 39 | 40 | other_state['state_dict'] = weight 41 | filename = join(sav_path, filename) 42 | torch.save(other_state, filename) 43 | 44 | 45 | 46 | def load_state_dict(model, state_dict, strict=True): 47 | if isinstance(model, (DistributedDataParallel, DataParallel)): 48 | model.module.load_state_dict(state_dict, strict=strict) 49 | else: 50 | model.load_state_dict(state_dict, strict=strict) 51 | 52 | 53 | def state_dict_remove_module(state_dict): 54 | new_state_dict = OrderedDict() 55 | for k, v in state_dict.items(): 56 | # name = k[7:] # remove 'module.' 
of dataparallel 57 | name = k.replace('module.', '') 58 | new_state_dict[name] = v 59 | return new_state_dict 60 | 61 | 62 | def reduce_tensor(tensor, args): 63 | rt = tensor.clone() 64 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 65 | rt /= args.world_size 66 | return rt 67 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 4 | class BaseModel(nn.Module): 5 | """ 6 | Base class for all models 7 | """ 8 | 9 | def __init__(self): 10 | super(BaseModel, self).__init__() 11 | # self.logger = logging.getLogger(self.__class__.__name__) 12 | 13 | def forward(self, *x): 14 | """ 15 | Forward pass logic 16 | 17 | :return: Model output 18 | """ 19 | raise NotImplementedError 20 | 21 | def summary(self, logger, writer): 22 | """ 23 | Model summary 24 | """ 25 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 26 | params = sum([np.prod(p.size()) for p in model_parameters]) / 1e6 # Unit is Mega 27 | logger.info(self) 28 | logger.info('===>Trainable parameters: %.3f M' % params) 29 | if writer is not None: 30 | writer.add_text('Model Summary', 'Trainable parameters: %.3f M' % params) 31 | -------------------------------------------------------------------------------- /base/config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import yaml 5 | import os 6 | from ast import literal_eval 7 | import copy 8 | 9 | 10 | class CfgNode(dict): 11 | """ 12 | CfgNode represents an internal node in the configuration tree. It's a simple 13 | dict-like container that allows for attribute-based access to keys. 
14 | """ 15 | 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | seperator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): 57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 58 | 59 | 60 | def load_cfg_from_cfg_file(file): 61 | cfg = {} 62 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 63 | '{} is not a yaml file'.format(file) 64 | 65 | with open(file, 'r') as f: 66 | cfg_from_file = yaml.safe_load(f) 67 | 68 | for key in cfg_from_file: 69 | for k, v in cfg_from_file[key].items(): 70 | cfg[k] = v 71 | 72 | cfg = CfgNode(cfg) 73 | return cfg 74 | 75 | 76 | def merge_cfg_from_list(cfg, cfg_list): 77 | new_cfg = copy.deepcopy(cfg) 78 | assert len(cfg_list) % 2 == 0 79 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 80 | subkey = full_key.split('.')[-1] 81 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 82 | value = _decode_cfg_value(v) 83 | value = _check_and_coerce_cfg_value_type( 84 | value, cfg[subkey], subkey, full_key 85 | ) 86 | setattr(new_cfg, subkey, value) 87 | 88 | return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | """Checks that `replacement`, which is intended to replace `original` is of 122 | the right type. 
The type is correct if it matches exactly or is one of a few 123 | cases in which the type can be easily coerced. 124 | """ 125 | original_type = type(original) 126 | replacement_type = type(replacement) 127 | 128 | # The types must match (with some exceptions) 129 | if replacement_type == original_type or original is None: 130 | return replacement 131 | 132 | # Cast replacement from from_type to to_type if the replacement and original 133 | # types match from_type and to_type 134 | def conditional_cast(from_type, to_type): 135 | if replacement_type == from_type and original_type == to_type: 136 | return True, to_type(replacement) 137 | else: 138 | return False, None 139 | 140 | # Conditionally casts 141 | # list <-> tuple 142 | casts = [(tuple, list), (list, tuple)] 143 | # For py2: allow converting from str (bytes) to a unicode string 144 | try: 145 | casts.append((str, unicode)) # noqa: F821 146 | except Exception: 147 | pass 148 | 149 | for (from_type, to_type) in casts: 150 | converted, converted_value = conditional_cast(from_type, to_type) 151 | if converted: 152 | return converted_value 153 | 154 | raise ValueError( 155 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 156 | "key: {}".format( 157 | original_type, replacement_type, original, replacement, full_key 158 | ) 159 | ) 160 | 161 | 162 | def _assert_with_logging(cond, msg): 163 | if not cond: 164 | logger.debug(msg) 165 | assert cond, msg 166 | -------------------------------------------------------------------------------- /base/utilities.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import random 5 | import time 6 | import logging 7 | import yaml 8 | import numpy as np 9 | import torch 10 | from base import config 11 | 12 | 13 | def yaml2json(yaml_dir): 14 | f = open(yaml_dir, 'r') 15 | ystr = f.read() 16 | params = yaml.load(ystr, Loader=yaml.FullLoader) 17 | info_dict={} 18 | info_dict['params']=params 19 | return info_dict 20 | 21 | 22 | def get_parser(): 23 | parser = argparse.ArgumentParser(description=' ') 24 | parser.add_argument('--config', type=str, default='./config/HDTF/config.yaml', help='config file') 25 | parser.add_argument('opts', help=' ', default=None, 26 | nargs=argparse.REMAINDER) 27 | args = parser.parse_args() 28 | assert args.config is not None 29 | cfg = config.load_cfg_from_cfg_file(args.config) 30 | if args.opts is not None: 31 | cfg = config.merge_cfg_from_list(cfg, args.opts) 32 | params_dict = yaml2json(args.config) 33 | return cfg, params_dict 34 | 35 | 36 | def get_logger(): 37 | logger_name = "main-logger" 38 | logger = logging.getLogger(logger_name) 39 | logger.setLevel(logging.INFO) 40 | handler = logging.StreamHandler() 41 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s" 42 | handler.setFormatter(logging.Formatter(fmt)) 43 | logger.addHandler(handler) 44 | return logger 45 | 46 | 47 | class AverageMeter(object): 48 | """Computes and stores the average and current value""" 49 | 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.val = 0 55 | self.avg = 0 56 | self.sum = 0 57 | self.count = 0 58 | 59 | def update(self, val, n=1): 60 | self.val = val 61 | self.sum += val * n 62 | self.count += n 63 | self.avg = self.sum / self.count 64 | 65 | 66 | def check_mkdir(dir_name): 67 | if not os.path.exists(dir_name): 68 | os.mkdir(dir_name) 69 | 70 | 71 | def check_makedirs(dir_name): 72 | if not 
os.path.exists(dir_name): 73 | os.makedirs(dir_name) 74 | 75 | 76 | def main_process(args): 77 | return not args.multiprocessing_distributed or ( 78 | args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0) 79 | 80 | 81 | def fixed_seed(seed=1): 82 | # seed = 1 83 | random.seed(seed) 84 | np.random.seed(seed) 85 | torch.manual_seed(seed) 86 | torch.cuda.manual_seed(seed) 87 | torch.cuda.manual_seed_all(seed) 88 | torch.backends.cudnn.deterministic = True 89 | torch.backends.cudnn.benchmark = False -------------------------------------------------------------------------------- /comparisons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/comparisons.png -------------------------------------------------------------------------------- /config/HDTF/config.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | dataset: HDTF 3 | data_root: ./3D-HDTF # your data root of 3D-HDTF 4 | audio_path: audio_16000Hz_from_video 5 | text_path: sentencestext 6 | codedict_path: spectre_processed_25fps_16kHz 7 | video_list_file: video_list_id_framenum-train_test.txt 8 | template_file: templates.pkl 9 | FLAME_templates: FLAME_templates/FLAME_sample.ply 10 | FLAME_template_HDTF: FLAME_HDTF_sample.npy 11 | landmark_embedding: external/spectre/data/landmark_embedding.npy 12 | lip_verts: FLAME_Regions/lve.txt 13 | id_num_all: 220 # 220 14 | train_ids: 150 # 150/220 15 | val_ids: 10 16 | val_video_list: video_list_id_framenum-val.txt 17 | video_fps: 25 18 | audio_fps: 50 19 | vocab_size: 32 20 | 21 | sentence_num: 40 # 50 22 | window: 150 23 | clip_overlap: True 24 | num_workers: 4 25 | 26 | NETWORK: 27 | motion_dim: 15069 28 | feature_dim: 128 29 | num_conv_layers: 5 30 | hidden_size: 128 31 | nhead_encoder: 4 32 | dim_feedforward: 256 33 | encoder_layers: 4 34 | style_pooling: 'mean' # mean/max 35 | style_pretrain: 36 | freeze_style_encoder: False 37 | content_norm: 'IN' # LN/IN 38 | content_attention: True 39 | wav2vec2model: facebook/wav2vec2-base-960h 40 | audio_hidden_size: 768 41 | freeze_TCN: True 42 | freeze_audio_encoder: False 43 | audio_feature_align: 'conv' # interpolation/conv 44 | style_fuse: 'SALN' # add/cat/adain/SALN 45 | 46 | decoder: 47 | nhead: 4 48 | dim_feedforward: 256 49 | num_layers: 1 50 | 51 | decoder_NAR: 52 | max_seq_len: 600 53 | decoder_layers: 4 54 | decoder_hidden: 64 55 | decoder_head: 4 56 | fft_conv1d_filter_size: 1024 57 | fft_conv1d_kernel_size: [9, 1] 58 | 59 | TRAIN: 60 | output: ./experiments 61 | exp_name: exp-AR-dim128-align_conv-IN-style_cls-grl-clip_loss-cycle_style_content-style_SALN 62 | batch_size: 6 63 | base_lr: 0.0001 64 | lr_sch_gamma: 0.1 65 | lr_sch_epoch: # 30 66 | start_epoch: 0 67 | epochs: 150 68 | save_freq: 5 69 | print_freq: 100 70 | visual_tsne: True 71 | with_val: True 72 | val_epoch: 5 73 | continue_ckpt: 74 | 75 | LOSS: 76 | recon_loss: 77 | use: True 78 | w: 1 79 | content_code_sim_loss: 80 | use: False 81 | w: 1.0e-6 82 | content_contrastive_loss: 83 | use: False 84 | w: 1.0e-5 85 | margin: 1.0 86 | content_clip_loss: 87 | use: True 88 | w: 5.0e-7 89 | content_ctc_loss: 90 | use: False 91 | w: 1.0e-8 92 | content_grl_loss: 93 | use: True 94 | use_grl: True 95 | w: 5.0e-7 96 | alpha: 1.0 97 | w_decay: 20 # 40 98 | style_class_loss: 99 | use: True 100 | use_metrics: False 101 | w: 2.5e-7 # 1.0e-6/1.0e-7/5.0e-7 102 | content_class_loss: 103 | use: False 104 | w: 
1.0e-7 105 | style_cycle_loss: 106 | use: True 107 | sim: 'cos' # cos/L1 108 | w: 2.0e-5 # 1.0e-4 109 | content_cycle_loss: 110 | use: True 111 | w: 5.0e-6 # 5.0e-6/1.0e-5 112 | 113 | DEMO: 114 | -------------------------------------------------------------------------------- /data/dataloader_HDTF.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['TRANSFORMERS_OFFLINE'] = '1' 3 | import torch 4 | import numpy as np 5 | import pickle 6 | import copy 7 | import random 8 | from tqdm import tqdm 9 | from transformers import Wav2Vec2Processor 10 | import librosa 11 | from collections import defaultdict 12 | from torch.utils import data 13 | 14 | class Dataset(data.Dataset): 15 | def __init__(self, cfg, data_type="train") -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | self.data_type = data_type 19 | self.fps = cfg.video_fps 20 | self.sentence_num = cfg.sentence_num 21 | self.window = cfg.window 22 | self.clip_overlap = self.cfg.clip_overlap 23 | self.data_root = cfg.data_root 24 | self.audio_path = os.path.join(self.data_root, cfg.audio_path) 25 | self.codedict_path = os.path.join(self.data_root, cfg.codedict_path) 26 | self.video_list_file = os.path.join(self.data_root, cfg.video_list_file) 27 | template_file = os.path.join(self.data_root, cfg.template_file) 28 | # with open(template_file, 'rb') as fin: 29 | self.templates = pickle.load(open(template_file, 'rb')) 30 | 31 | self.processor = Wav2Vec2Processor.from_pretrained(cfg.wav2vec2model) 32 | 33 | self.get_data() 34 | self.one_hot_init = np.eye(cfg.train_ids) 35 | if data_type == 'train': 36 | train_id_list = list(range(cfg.train_ids)) 37 | id_list = [id for id in self.id_list if id-1 in train_id_list] 38 | video_list = [] 39 | for i, id in enumerate(self.id_list): 40 | if id-1 in train_id_list: 41 | video_list.append(self.video_list[i]) 42 | self.id_list = id_list 43 | self.video_list = video_list 44 | 45 | self.id_set = sorted(list(set(self.id_list))) 46 | print('train_ids: ', self.id_set) 47 | 48 | def __len__(self): 49 | return len(self.video_list)*self.sentence_num 50 | 51 | def __getitem__(self, index): 52 | video_idx = index%len(self.video_list) 53 | video_label = self.video_list[video_idx] 54 | # subject_id = video_label.split('_')[1] 55 | subject_id = self.id_list[video_idx] 56 | # onehot 57 | # one_hot = self.one_hot_init[self.id_list.index(subject_id)] 58 | one_hot = self.one_hot_init[subject_id-1] 59 | 60 | # template 61 | template = self.templates[subject_id].reshape(-1) 62 | # vertices 63 | codedict_npy = os.path.join(self.codedict_path, video_label, 'verts_new_shape1.npy') 64 | codedict = np.load(codedict_npy, allow_pickle=True).item() 65 | vertices_all = codedict['verts'].reshape(-1, self.cfg.motion_dim) 66 | frame_num = vertices_all.shape[0] 67 | if self.clip_overlap: 68 | max_frame = frame_num-self.window-1 69 | start_frame = random.randint(0, max_frame) 70 | else: 71 | max_idx = frame_num//self.window-1 72 | start_idx = random.randint(0, max_idx) 73 | start_frame = start_idx*self.window 74 | vertices = vertices_all[start_frame: start_frame+self.window, :] 75 | # init_state 76 | init_state = np.zeros((vertices.shape[-1])) if start_frame==0 else vertices_all[start_frame-1, :] 77 | 78 | # audio 79 | # huggingface 80 | # wav_path = os.path.join(self.audio_path, video_label+'.wav') 81 | # speech_array, sampling_rate = librosa.load(wav_path, sr=16000) 82 | audio_npy = os.path.join(self.codedict_path, video_label, 'audio_librosa.npy') 83 | audio_read = 
np.load(audio_npy, allow_pickle=True).item() 84 | speech_array, sampling_rate = audio_read['speech_array'], audio_read['sampling_rate'] 85 | stride = int(sampling_rate//self.fps) 86 | audio = speech_array[start_frame*stride: (start_frame+self.window)*stride] 87 | audio = np.squeeze(self.processor(audio,sampling_rate=16000).input_values) 88 | if audio.shape[0]<96000: 89 | audio = np.concatenate([audio, np.zeros((96000-audio.shape[0]))], axis=0) 90 | 91 | return { 92 | 'audio': torch.FloatTensor(audio), 93 | 'vertices': torch.FloatTensor(vertices), 94 | 'template': torch.FloatTensor(template), 95 | 'one_hot': torch.FloatTensor(one_hot), 96 | 'subject_id': subject_id, 97 | 'init_state': torch.FloatTensor(init_state), 98 | } 99 | 100 | 101 | def get_data(self): 102 | self.video_dict = {} 103 | with open(self.video_list_file, 'r') as f: 104 | lines = f.readlines() 105 | for line in lines: 106 | if self.data_type=='train' and line.strip().split(' ')[-1] != 'test': 107 | video = line.strip().split(' ')[0] 108 | id = int(line.strip().split(' ')[1]) 109 | self.video_dict[video] = id 110 | self.video_list = list(self.video_dict.keys()) 111 | self.id_list = list(self.video_dict.values()) 112 | 113 | 114 | def get_dataloaders(cfg): 115 | dataset = {} 116 | train_data = Dataset(cfg,data_type="train") 117 | dataset["train"] = data.DataLoader(dataset=train_data, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers) 118 | # valid_data = Dataset(cfg,data_type="val") 119 | # dataset["valid"] = data.DataLoader(dataset=valid_data, batch_size=1, shuffle=False, num_workers=cfg.num_workers) 120 | return dataset 121 | 122 | 123 | if __name__ == '__main__': 124 | import sys 125 | sys.path.append('.') 126 | from base.utilities import get_parser 127 | cfg, params_dict = get_parser() 128 | dataset = get_dataloaders(cfg) 129 | train_loader = dataset['train'] 130 | for i, d in enumerate(train_loader): 131 | print(d) -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import pickle 5 | import shutil 6 | import random 7 | import librosa 8 | import torch 9 | from base.config import CfgNode 10 | from transformers import Wav2Vec2Processor 11 | from tools.render_spectre import Render 12 | 13 | from models.network import DisNetAutoregCycle as Model 14 | 15 | cv2.ocl.setUseOpenCL(False) 16 | cv2.setNumThreads(0) 17 | import tempfile 18 | from subprocess import call 19 | from psbody.mesh import Mesh 20 | from utils.render_pyrender import render_mesh_helper 21 | 22 | 23 | class Infer(): 24 | def __init__(self, ckpt, device='cuda:0') -> None: 25 | super().__init__() 26 | self.device = device 27 | ckpt_info = torch.load(ckpt, map_location='cpu') 28 | param_dict = ckpt_info['params'] 29 | cfg = {} 30 | for key in param_dict: 31 | for k, v in param_dict[key].items(): 32 | cfg[k] = v 33 | cfg = CfgNode(cfg) 34 | self.cfg = cfg 35 | 36 | self.fps = cfg.video_fps 37 | self.window = cfg.window 38 | self.motion_dim = cfg.motion_dim 39 | self.overlap = 10 40 | self.image_size = 224 41 | self.demo_path = 'demos' 42 | 43 | # load model 44 | weights = ckpt_info['model'] 45 | self.model = Model(cfg) 46 | self.model.load_state_dict(weights) 47 | self.model.eval().to(device) 48 | 49 | # audio processor 50 | self.processor = Wav2Vec2Processor.from_pretrained(cfg.wav2vec2model) 51 | 52 | # render from spectre 53 | self.render = 
Render(image_size=self.image_size, background=True, device=device) 54 | 55 | 56 | def infer_from_wav(self, wav_file, style_ref, save_path, use_pytorch3d=True): 57 | os.makedirs(save_path, exist_ok=True) 58 | 59 | # load audio-huggingface 60 | speech_array, sampling_rate = librosa.load(wav_file, sr=16000) 61 | stride = int(sampling_rate//self.fps) 62 | audio_frame_num = speech_array.shape[0]//stride 63 | 64 | # load style reference 65 | id = style_ref.split('-')[0] 66 | video_name = style_ref.split('-')[1] 67 | style_npy = os.path.join(self.demo_path, 'style_ref', id, video_name, 'verts_new_shape1.npy') 68 | codedict = np.load(style_npy, allow_pickle=True).item() 69 | vertices = codedict['verts'].reshape(-1, self.motion_dim) 70 | vertices = torch.FloatTensor(vertices) # (L, 15069) 71 | style_input_vertices = vertices[-self.window:].unsqueeze(0).to(self.device) 72 | 73 | # template 74 | template_pkl = os.path.join(self.demo_path, 'style_ref', id, style_ref+'.pkl') 75 | template = pickle.load(open(template_pkl, 'rb')) 76 | template = template.reshape(-1, self.motion_dim) 77 | template = torch.FloatTensor(template).to(self.device) 78 | 79 | # infer 80 | f = 0 81 | while f < audio_frame_num: 82 | if f == 0: 83 | init_state = torch.zeros((1, self.motion_dim)).to(self.device) 84 | temp_end_f = min(f+self.window, audio_frame_num) 85 | else: 86 | temp_end_f = min(f+self.window-self.overlap, audio_frame_num) 87 | temp_start_f = max(0, f-self.overlap) 88 | audio_batch = speech_array[temp_start_f*stride: temp_end_f*stride] 89 | audio_batch = np.squeeze(self.processor(audio_batch,sampling_rate=16000).input_values) 90 | audio_batch = torch.FloatTensor(audio_batch).to(self.device).unsqueeze(0) 91 | with torch.no_grad(): 92 | out_batch, _ = self.model.predict(audio_batch, style_input_vertices, init_state, template) 93 | 94 | init_state = out_batch[:, -self.overlap, :]-template 95 | if f == 0: 96 | pred = out_batch.cpu() 97 | else: 98 | pred = torch.cat([pred, out_batch[:, self.overlap:, :].cpu()], dim=1) 99 | f = temp_end_f 100 | pred = pred.squeeze(0).numpy() 101 | pred = np.reshape(pred,(-1,self.motion_dim//3,3)) 102 | 103 | # save verts 104 | wav_name = os.path.basename(wav_file).replace('.wav', '') 105 | file_name = "audio_{}-style_{}".format(wav_name, style_ref) 106 | save_npy = os.path.join(save_path, file_name+'.npy') 107 | os.makedirs(os.path.join(save_path), exist_ok=True) 108 | save_dict = { 109 | 'verts_pred': pred, 110 | } 111 | np.save(save_npy, save_dict) 112 | print('Saved {}'.format(save_npy)) 113 | 114 | # render 115 | num_frames = pred.shape[0] 116 | tmp_video_file = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir=save_path) 117 | writer = cv2.VideoWriter(tmp_video_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, (self.image_size, self.image_size), True) 118 | center = np.mean(pred[0], axis=0) 119 | for i_frame in range(num_frames): 120 | if use_pytorch3d: 121 | pred_img = self.render(pred[i_frame:i_frame+1]) 122 | pred_img = pred_img[0] 123 | else: 124 | render_mesh = Mesh(pred[i_frame], self.template_mesh.f) 125 | pred_img = render_mesh_helper(self.cfg, render_mesh, center) 126 | pred_img = pred_img.astype(np.uint8) 127 | writer.write(pred_img) 128 | writer.release() 129 | video_fname = os.path.join(save_path, file_name+'-no_audio.mp4') 130 | cmd = ('ffmpeg' + ' -i {0} -pix_fmt yuv420p -qscale 0 {1}'.format(tmp_video_file.name, video_fname)).split() 131 | call(cmd) 132 | # add audio 133 | cmd = ('ffmpeg' + ' -i {0} -i {1} -vcodec h264 -ac 2 -channel_layout stereo -qscale 0 
{2}'.format(wav_file, video_fname, video_fname.replace('-no_audio.mp4', '.mp4'))).split() 134 | call(cmd) 135 | if os.path.exists(video_fname): 136 | os.remove(video_fname) 137 | 138 | 139 | if __name__ == '__main__': 140 | import argparse 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--wav_file', type=str, default='demos/wav/RD_Radio11_001.wav', help='audio input') 143 | parser.add_argument('--style_ref', type=str, default='id_002-RD_Radio11_001', help='style reference name') 144 | parser.add_argument('--checkpoint', type=str, default='pretrained/exp-AR-dim128-IN-style_cls-grl-clip_loss-cycle_style_content-style_SALN-no_sch-epoch_125/Epoch_125.pth', help='checkpoint for inference') 145 | parser.add_argument('--output_path', type=str, default='demos/results', help='output path') 146 | parser.add_argument('--use_pytorch3d', type=bool, default=True, help='whether to use PyTorch3D for rendering, if False, pyrender will be used') 147 | 148 | args = parser.parse_args() 149 | 150 | # infer 151 | infer = Infer(ckpt=args.checkpoint) 152 | infer.infer_from_wav(wav_file=args.wav_file, style_ref=args.style_ref, save_path=args.output_path) 153 | -------------------------------------------------------------------------------- /external/spectre/.gitignore: -------------------------------------------------------------------------------- 1 | # # Compiled source # 2 | # ################### 3 | # *.o 4 | # *.so 5 | 6 | # # Packages # 7 | # ############ 8 | # # it's better to unpack these files and commit the raw source 9 | # # git has its own built in compression methods 10 | # *.7z 11 | # *.dmg 12 | # *.gz 13 | # *.iso 14 | # *.jar 15 | # *.rar 16 | # *.tar 17 | # *.zip 18 | 19 | # # OS generated files # 20 | # ###################### 21 | # .DS_Store 22 | # .DS_Store? 23 | # ._* 24 | # .Spotlight-V100 25 | # .Trashes 26 | # ehthumbs.db 27 | # Thumbs.db 28 | # .vscode 29 | 30 | # # 3D data # 31 | # ############ 32 | # *.mat 33 | # *.pkl 34 | # *.obj 35 | # *.dat 36 | # *.npz 37 | 38 | # # python file # 39 | # ############ 40 | # *.pyc 41 | # __pycache__ 42 | 43 | # ## deca data 44 | # *results* 45 | # # *_vis.jpg 46 | # ## internal use 47 | # cluster_scripts 48 | # internal 49 | # data/FLAME2020 50 | # data/FLAMETexture 51 | # misc/ 52 | # data/LRS3_V_WER32.3/ 53 | # data/MEAD_test_text.pth 54 | # data/MEAD_train_text.pth 55 | # data/ResNet50/ 56 | # logs/ 57 | -------------------------------------------------------------------------------- /external/spectre/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/face_alignment"] 2 | path = external/face_alignment 3 | url = https://github.com/hhj1897/face_alignment.git 4 | [submodule "external/av_hubert"] 5 | path = external/av_hubert 6 | url = https://github.com/facebookresearch/av_hubert 7 | [submodule "external/face_detection"] 8 | path = external/face_detection 9 | url = https://github.com/hhj1897/face_detection 10 | [submodule "external/Visual_Speech_Recognition_for_Multiple_Languages"] 11 | path = external/Visual_Speech_Recognition_for_Multiple_Languages 12 | url = https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages 13 | -------------------------------------------------------------------------------- /external/spectre/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # SPECTRE: Visual Speech-Aware Perceptual 3D Facial Expression Reconstruction from Videos 4 | 5 | [![Paper](https://img.shields.io/badge/arXiv-2207.11094-brightgreen)](https://arxiv.org/abs/2207.11094) 6 |   [![Project WebPage](https://img.shields.io/badge/Project-webpage-blue)](https://filby89.github.io/spectre/) 7 |   8 | Youtube Video 9 | 10 |
21 | Our method performs visual-speech aware 3D reconstruction so that speech perception from the original footage is preserved in the reconstructed talking head. On the left we include the word/phrase being said for each example.
22 | 23 | This is the official Pytorch implementation of the paper: 24 | 25 | ``` 26 | Visual Speech-Aware Perceptual 3D Facial Expression Reconstruction from Videos 27 | Panagiotis P. Filntisis, George Retsinas, Foivos Paraperas-Papantoniou, Athanasios Katsamanis, Anastasios Roussos, and Petros Maragos 28 | arXiv 2022 29 | ``` 30 | 31 | 32 | 33 | ## Installation 34 | Clone the repo and its submodules: 35 | ```bash 36 | git clone --recurse-submodules -j4 https://github.com/filby89/spectre 37 | cd spectre 38 | ``` 39 | 40 | You need to have installed a working version of Pytorch with Python 3.6 or higher and Pytorch 3D. You can use the following commands to create a working installation: 41 | ```bash 42 | conda create -n "spectre" python=3.8 43 | conda install -c pytorch pytorch=1.11.0 torchvision torchaudio # you might need to select cudatoolkit version here by adding e.g. cudatoolkit=11.3 44 | conda install -c conda-forge -c fvcore fvcore iopath 45 | conda install pytorch3d -c pytorch3d 46 | pip install -r requirements.txt # install the rest of the requirements 47 | ``` 48 | 49 | Installing a working setup of Pytorch3d with Pytorch can be a bit tricky. For development we used Pytorch3d 0.6.1 with Pytorch 1.10.0. 50 | 51 | PyTorch3d 0.6.2 with pytorch 1.11.0 are also compatible. 52 | 53 | Install the face_alignment and face_detection packages: 54 | ```bash 55 | cd external/face_alignment 56 | pip install -e . 57 | cd ../face_detection 58 | git lfs pull 59 | pip install -e . 60 | cd ../.. 61 | ``` 62 | You may need to install git-lfs to run the above commands. [More details](https://stackoverflow.com/questions/48734119/git-lfs-is-not-a-git-command-unclear) 63 | ```bash 64 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 65 | sudo apt-get install git-lfs 66 | ``` 67 | Download the FLAME model and the pretrained SPECTRE model: 68 | ```bash 69 | pip install gdown 70 | bash quick_install.sh 71 | ``` 72 | 73 | ## Demo 74 | Samples are included in ``samples`` folder. You can run the demo by running 75 | 76 | ```bash 77 | python demo.py --input samples/LRS3/0Fi83BHQsMA_00002.mp4 --audio 78 | ``` 79 | 80 | The audio flag extracts audio from the input video and puts it in the output shape video for visualization purposes (ffmpeg is required for video creation). 81 | 82 | ## Training and Testing 83 | In order to train the model you need to download the `trainval` and `test` sets of the [LRS3 dataset](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs3.html). After downloading 84 | the dataset, run the following command to extract frames and audio from the videos (audio is not needed for training but it is nice for visualizing the result): 85 | 86 | ```bash 87 | python utils/extract_frames_and_audio.py --dataset_path ./data/LRS3 88 | ``` 89 | 90 | After downloading and preprocessing the dataset, download the rest needed assets: 91 | 92 | ```bash 93 | bash get_training_data.sh 94 | ``` 95 | 96 | This command downloads the original [DECA](https://github.com/YadiraF/DECA/) pretrained model, 97 | the ResNet50 emotion recognition model provided by [EMOCA](https://github.com/radekd91/emoca), 98 | the pretrained lipreading model and detected landmarks for the videos of the LRS3 dataset provided by [Visual_Speech_Recognition_for_Multiple_Languages](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages). 
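As a reference for the expected data layout, here is a sketch inferred from `datasets/datasets.py` and `datasets/data_utils.py` (not an official specification; the clip ID below is illustrative). Each list entry of the form `trainval/<speaker>/<clip>` should resolve to a folder of extracted frames, an extracted wav next to it, and a landmarks pickle under the landmarks root:

```python
import os

LRS3_path = "data/LRS3"                      # value passed as --LRS3_path
LRS3_landmarks_path = "data/LRS3_landmarks"  # value passed as --LRS3_landmarks_path
sample = "trainval/0af00UcTOSc/50001"        # illustrative clip id

print(os.path.isfile(os.path.join(LRS3_path, sample, "000000.jpg")))       # frames: %06d.jpg inside the clip folder
print(os.path.isfile(os.path.join(LRS3_path, sample) + ".wav"))            # audio extracted next to the clip folder
print(os.path.isfile(os.path.join(LRS3_landmarks_path, sample) + ".pkl"))  # downloaded landmarks
```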
99 | 100 | Finally, you need to create a texture model using the repository [BFM_to_FLAME](https://github.com/TimoBolkart/BFM_to_FLAME#create-texture-model). Due 101 | to licensing reasons we are not allowed to share it with you. 102 | 103 | Now, you can run the following command to train the model: 104 | 105 | ```bash 106 | python main.py --output_dir logs --landmark 50 --relative_landmark 25 --lipread 2 --expression 0.5 --epochs 6 --LRS3_path data/LRS3 --LRS3_landmarks_path data/LRS3_landmarks 107 | ``` 108 | 109 | and then test it on the LRS3 test set: 110 | 111 | ```bash 112 | python main.py --test --output_dir logs --model_path logs/model.tar --LRS3_path data/LRS3 --LRS3_landmarks_path data/LRS3_landmarks 113 | ``` 114 | 115 | and run lipreading with AV-hubert: 116 | 117 | ```bash 118 | # and run lipreading with our script 119 | python utils/run_av_hubert.py --videos "logs/test_videos_000000/*_mouth.avi" --LRS3_path data/LRS3 120 | ``` 121 | 122 | 123 | ## Acknowledgements 124 | This repo has been heavily based on the original implementation of [DECA](https://github.com/YadiraF/DECA/). We also acknowledge the following 125 | repositories, from which we have benefited greatly: 126 | 127 | - [EMOCA](https://github.com/radekd91/emoca) 128 | - [face_alignment](https://github.com/hhj1897/face_alignment) 129 | - [face_detection](https://github.com/hhj1897/face_detection) 130 | - [Visual_Speech_Recognition_for_Multiple_Languages](https://github.com/mpc001/Visual_Speech_Recognition_for_Multiple_Languages) 131 | 132 | ## Citation 133 | If your research benefits from this repository, consider citing the following: 134 | 135 | ``` 136 | @misc{filntisis2022visual, 137 | title = {Visual Speech-Aware Perceptual 3D Facial Expression Reconstruction from Videos}, 138 | author = {Filntisis, Panagiotis P.
and Retsinas, George and Paraperas-Papantoniou, Foivos and Katsamanis, Athanasios and Roussos, Anastasios and Maragos, Petros}, 139 | publisher = {arXiv}, 140 | year = {2022}, 141 | } 142 | ``` 143 | 144 | 145 | -------------------------------------------------------------------------------- /external/spectre/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/__init__.py -------------------------------------------------------------------------------- /external/spectre/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Default config for SPECTRE - adapted from DECA 3 | ''' 4 | from yacs.config import CfgNode as CN 5 | import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | 11 | cfg.project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src', '..')) 12 | cfg.device = 'cuda' 13 | cfg.device_ids = '0' 14 | 15 | cfg.pretrained_modelpath = os.path.join(cfg.project_dir, 'data', 'deca_model.tar') 16 | cfg.output_dir = '' 17 | cfg.rasterizer_type = 'pytorch3d' 18 | # ---------------------------------------------------------------------------- # 19 | # Options for FLAME and from original DECA 20 | # ---------------------------------------------------------------------------- # 21 | cfg.model = CN() 22 | cfg.model.topology_path = os.path.join(cfg.project_dir, 'data' , 'head_template.obj') 23 | # texture data original from http://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_texture_data.zip 24 | cfg.model.dense_template_path = os.path.join(cfg.project_dir, 'data', 'texture_data_256.npy') 25 | cfg.model.fixed_displacement_path = os.path.join(cfg.project_dir, 'data', 'fixed_displacement_256.npy') 26 | cfg.model.flame_model_path = os.path.join(cfg.project_dir, 'data', 'FLAME2020', 'generic_model.pkl') 27 | cfg.model.flame_lmk_embedding_path = os.path.join(cfg.project_dir, 'data', 'landmark_embedding.npy') 28 | cfg.model.face_mask_path = os.path.join(cfg.project_dir, 'data', 'uv_face_mask.png') 29 | cfg.model.face_eye_mask_path = os.path.join(cfg.project_dir, 'data', 'uv_face_eye_mask.png') 30 | cfg.model.mean_tex_path = os.path.join(cfg.project_dir, 'data', 'mean_texture.jpg') 31 | cfg.model.tex_path = os.path.join(cfg.project_dir, 'data', 'FLAME_albedo_from_BFM.npz') 32 | cfg.model.tex_type = 'BFM' # BFM, FLAME, albedoMM 33 | cfg.model.uv_size = 256 34 | cfg.model.param_list = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 35 | cfg.model.n_shape = 100 36 | cfg.model.n_tex = 50 37 | cfg.model.n_exp = 50 38 | cfg.model.n_cam = 3 39 | cfg.model.n_pose = 6 40 | cfg.model.n_light = 27 41 | cfg.model.jaw_type = 'aa' # default use axis angle, another option: euler. 
Note that: aa is not stable in the beginning 42 | 43 | 44 | 45 | cfg.model.model_type = "SPECTRE" 46 | 47 | cfg.model.temporal = True 48 | 49 | 50 | # ---------------------------------------------------------------------------- # 51 | # Options for Dataset 52 | # ---------------------------------------------------------------------------- # 53 | cfg.dataset = CN() 54 | cfg.dataset.LRS3_path = "/gpu-data3/filby/LRS3" 55 | cfg.dataset.LRS3_landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/LRS3/LRS3_landmarks" 56 | 57 | cfg.dataset.LRS3_path = "/gpu-data3/filby/LRS3" 58 | cfg.dataset.LRS3_landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/LRS3/LRS3_landmarks" 59 | 60 | cfg.dataset.LRS3_path = "/gpu-data3/filby/LRS3" 61 | cfg.dataset.LRS3_landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/LRS3/LRS3_landmarks" 62 | 63 | cfg.dataset.batch_size = 1 64 | cfg.dataset.K = 20 65 | cfg.dataset.num_workers = 8 66 | cfg.dataset.image_size = 224 # 224/500 67 | cfg.dataset.scale_min = 1.4 68 | cfg.dataset.scale_max = 1.8 69 | cfg.dataset.trans_scale = 0. 70 | cfg.dataset.fps = 25 71 | cfg.dataset.test_datasets = ['LRS3'] 72 | 73 | # ---------------------------------------------------------------------------- # 74 | # Options for training 75 | # ---------------------------------------------------------------------------- # 76 | cfg.train = CN() 77 | cfg.train.max_epochs = 6 78 | cfg.train.log_dir = 'logs' 79 | cfg.train.log_steps = 10 80 | cfg.train.vis_dir = 'train_images' 81 | cfg.train.vis_steps = 500 82 | cfg.train.write_summary = True 83 | cfg.train.checkpoint_steps = 10000 84 | cfg.train.val_vis_dir = 'val_images' 85 | 86 | cfg.train.evaluation_steps = 10000 87 | 88 | # ---------------------------------------------------------------------------- # 89 | # Options for Losses 90 | # ---------------------------------------------------------------------------- # 91 | cfg.loss = CN() 92 | cfg.loss.train = CN() 93 | 94 | cfg.model.use_tex = True 95 | cfg.model.regularization_type = 'nonlinear' 96 | cfg.model.backbone = 'mobilenetv2' # perceptual encoder backbone 97 | 98 | cfg.loss.train.landmark = 50 99 | cfg.loss.train.lip_landmarks = 0 100 | cfg.loss.train.relative_landmark = 50# 50 101 | cfg.loss.train.photometric_texture = 0 102 | cfg.loss.train.lipread = 2 103 | cfg.loss.train.jaw_reg = 200 104 | cfg.train.lr = 5e-5 105 | cfg.loss.train.expression = 0.5 106 | 107 | cfg.test_mode = False 108 | 109 | def get_cfg_defaults(): 110 | """Get a yacs CfgNode object with default values for my_project.""" 111 | # Return a clone so that the defaults will not be altered 112 | # This is for the "local variable" use pattern 113 | return cfg.clone() 114 | 115 | def update_cfg(cfg, cfg_file): 116 | cfg.merge_from_file(cfg_file) 117 | return cfg.clone() 118 | 119 | def parse_args(): 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--output_dir', type=str, help='output path') 122 | parser.add_argument('--LRS3_path', default=None, type=str, help='path to LRS3 dataset') 123 | parser.add_argument('--LRS3_landmarks_path', default=None, type=str, help='path to LRS3 landmarks') 124 | parser.add_argument('--model_path', default=None, help='path to pretrained model') 125 | parser.add_argument('--batch-size', type=int, default=1, help='the batch size') 126 | parser.add_argument('--epochs', type=int, default=6, help='number of epochs to train for') 127 | parser.add_argument('--K', type=int, default=20, help='length of sampled 
frame sequence') 128 | parser.add_argument('--lipread', type=float, default=None, help='lipread loss weight') 129 | parser.add_argument('--expression', type=float, default=None, help='expression loss weight') 130 | parser.add_argument('--lr', type=float, default=None, help='learning rate') 131 | parser.add_argument('--landmark', type=float, default=None, help='landmark loss weight') 132 | parser.add_argument('--relative_landmark', type=float, default=None, help='relative landmark loss weight') 133 | parser.add_argument('--backbone', type=str, default='mobilenetv2', choices=['mobilenetv2', 'resnet50']) 134 | 135 | parser.add_argument('--test', action='store_true', help='test mode') 136 | parser.add_argument('--test_datasets', type=str, nargs='+', default=['LRS3'], help='test datasets') 137 | 138 | args = parser.parse_args() 139 | 140 | cfg = get_cfg_defaults() 141 | 142 | cfg.output_dir = args.output_dir 143 | 144 | if args.model_path is not None: 145 | cfg.pretrained_modelpath = args.model_path 146 | 147 | if args.batch_size is not None: 148 | cfg.dataset.batch_size = args.batch_size 149 | 150 | cfg.dataset.K = args.K 151 | 152 | if args.landmark is not None: 153 | cfg.loss.train.landmark = args.landmark 154 | 155 | if args.relative_landmark is not None: 156 | cfg.loss.train.relative_landmark = args.relative_landmark 157 | 158 | if args.lipread is not None: 159 | cfg.loss.train.lipread = args.lipread 160 | 161 | if args.expression is not None: 162 | cfg.loss.train.expression = args.expression 163 | 164 | if args.lr is not None: 165 | cfg.train.lr = args.lr 166 | 167 | if args.epochs is not None: 168 | cfg.train.max_epochs = args.epochs 169 | 170 | if args.LRS3_path is not None: 171 | cfg.dataset.LRS3_path = args.LRS3_path 172 | 173 | if args.LRS3_landmarks_path is not None: 174 | cfg.dataset.LRS3_landmarks_path = args.LRS3_landmarks_path 175 | 176 | cfg.model.backbone = args.backbone 177 | 178 | cfg.test_mode = args.test 179 | 180 | cfg.test_datasets = args.test_datasets 181 | 182 | return cfg 183 | -------------------------------------------------------------------------------- /external/spectre/configs/lipread_config.ini: -------------------------------------------------------------------------------- 1 | [input] 2 | modality=video 3 | v_fps=25 4 | 5 | [model] 6 | v_fps=25 7 | model_path=data/LRS3_V_WER32.3/model.pth 8 | model_conf=data/LRS3_V_WER32.3/model.json 9 | rnnlm= 10 | rnnlm_conf= 11 | 12 | [decode] 13 | beam_size=1 14 | penalty=0.5 15 | maxlenratio=0.0 16 | minlenratio=0.0 17 | ctc_weight=0.1 18 | lm_weight=0.6 19 | -------------------------------------------------------------------------------- /external/spectre/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/datasets/__init__.py -------------------------------------------------------------------------------- /external/spectre/datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def linear_interpolate(landmarks, start_idx, stop_idx): 4 | """linear_interpolate. 5 | 6 | :param landmarks: ndarray, input landmarks to be interpolated. 7 | :param start_idx: int, the start index for linear interpolation. 8 | :param stop_idx: int, the stop for linear interpolation. 
9 | """ 10 | start_landmarks = landmarks[start_idx] 11 | stop_landmarks = landmarks[stop_idx] 12 | delta = stop_landmarks - start_landmarks 13 | for idx in range(1, stop_idx-start_idx): 14 | landmarks[start_idx+idx] = start_landmarks + idx/float(stop_idx-start_idx) * delta 15 | return landmarks 16 | 17 | def landmarks_interpolate(landmarks): 18 | """landmarks_interpolate. 19 | 20 | :param landmarks: List, the raw landmark (in-place) 21 | 22 | """ 23 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None] 24 | if not valid_frames_idx: 25 | return None 26 | for idx in range(1, len(valid_frames_idx)): 27 | if valid_frames_idx[idx] - valid_frames_idx[idx - 1] == 1: 28 | continue 29 | else: 30 | landmarks = linear_interpolate(landmarks, valid_frames_idx[idx - 1], valid_frames_idx[idx]) 31 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None] 32 | # -- Corner case: keep frames at the beginning or at the end failed to be detected. 33 | if valid_frames_idx: 34 | landmarks[:valid_frames_idx[0]] = [landmarks[valid_frames_idx[0]]] * valid_frames_idx[0] 35 | landmarks[valid_frames_idx[-1]:] = [landmarks[valid_frames_idx[-1]]] * (len(landmarks) - valid_frames_idx[-1]) 36 | valid_frames_idx = [idx for idx, _ in enumerate(landmarks) if _ is not None] 37 | assert len(valid_frames_idx) == len(landmarks), "not every frame has landmark" 38 | return landmarks 39 | 40 | 41 | def create_LRS3_lists(lrs3_path): 42 | from sklearn.model_selection import train_test_split 43 | import pickle 44 | trainval_folder_list = list(os.listdir(f"{lrs3_path}/trainval")) 45 | train_folder_list, val_folder_list = train_test_split(trainval_folder_list, test_size=0.2, random_state=42) 46 | 47 | 48 | train_list = [] 49 | for folder in train_folder_list: 50 | for file in os.listdir(os.path.join(f"{lrs3_path}/trainval", folder)): 51 | if file.endswith(".txt"): 52 | file_without_extension = file.split(".")[0] 53 | train_list.append(f"trainval/{folder}/{file_without_extension}") 54 | 55 | 56 | val_list = [] 57 | for folder in val_folder_list: 58 | for file in os.listdir(os.path.join(f"{lrs3_path}/trainval", folder)): 59 | if file.endswith(".txt"): 60 | file_without_extension = file.split(".")[0] 61 | val_list.append(f"trainval/{folder}/{file_without_extension}") 62 | 63 | # 64 | test_folder_list = list(os.listdir(f"{lrs3_path}/test")) 65 | test_list = [] 66 | for folder in test_folder_list: 67 | for file in os.listdir(os.path.join(f"{lrs3_path}/test", folder)): 68 | if file.endswith(".txt"): 69 | file_without_extension = file.split(".")[0] 70 | test_list.append(f"test/{folder}/{file_without_extension}") 71 | 72 | 73 | pickle.dump([train_list,val_list,test_list], open(f"data/LRS3_lists.pkl", "wb")) 74 | -------------------------------------------------------------------------------- /external/spectre/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import cv2 6 | from skimage.transform import estimate_transform, warp 7 | import random 8 | import pickle 9 | from .data_utils import landmarks_interpolate 10 | 11 | class SpectreDataset(Dataset): 12 | def __init__(self, data_list, landmarks_path, cfg, test=False): 13 | self.data_list = data_list 14 | self.image_size = 224 15 | self.K = cfg.K 16 | self.test = test 17 | self.cfg=cfg 18 | self.landmarks_path = landmarks_path 19 | 20 | if not self.test: 21 | self.scale = [1.4, 1.8] 22 | else: 23 | 
self.scale = 1.6 24 | 25 | def crop_face(self, frame, landmarks, scale=1.0): 26 | left = np.min(landmarks[:, 0]) 27 | right = np.max(landmarks[:, 0]) 28 | top = np.min(landmarks[:, 1]) 29 | bottom = np.max(landmarks[:, 1]) 30 | 31 | h, w, _ = frame.shape 32 | old_size = (right - left + bottom - top) / 2 33 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) # + old_size*0.1]) 34 | 35 | size = int(old_size * scale) 36 | 37 | # crop image 38 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2], 39 | [center[0] + size / 2, center[1] - size / 2]]) 40 | DST_PTS = np.array([[0, 0], [0, self.image_size - 1], [self.image_size - 1, 0]]) 41 | tform = estimate_transform('similarity', src_pts, DST_PTS) 42 | 43 | return tform 44 | 45 | def __len__(self): 46 | return len(self.data_list) 47 | 48 | def __getitem__(self, index): 49 | images_list = []; kpt_list = []; 50 | 51 | sample = self.data_list[index] 52 | 53 | landmarks_filename = os.path.join(self.landmarks_path, sample[0]+".pkl") 54 | folder_path = os.path.join(self.cfg.LRS3_path, sample[0]) 55 | 56 | with open(landmarks_filename, "rb") as pkl_file: 57 | landmarks = pickle.load(pkl_file) 58 | preprocessed_landmarks = landmarks_interpolate(landmarks) 59 | if preprocessed_landmarks is None: 60 | return None 61 | 62 | if self.test: 63 | frame_indices = list(range(len(landmarks))) 64 | else: 65 | if len(landmarks) < self.K: 66 | start_idx = 0 67 | end_idx = len(landmarks) 68 | else: 69 | start_idx = random.randint(0, len(landmarks) - self.K) 70 | end_idx = start_idx + self.K 71 | 72 | frame_indices = list(range(start_idx,end_idx)) 73 | 74 | if isinstance(self.scale, list): 75 | scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0] 76 | else: 77 | scale = self.scale 78 | 79 | for frame_idx in frame_indices: 80 | if "LRS3" in self.landmarks_path: 81 | frame = cv2.imread(os.path.join(folder_path,"%06d.jpg"%(frame_idx))) 82 | folder_path = os.path.join(self.cfg.LRS3_path, sample[0]) 83 | wav = folder_path + ".wav" 84 | else: # during test mode for other datasets 85 | if 'MEAD' in self.landmarks_path: 86 | folder_path = os.path.join("/gpu-data3/filby/MEAD/rendered/train/MEAD/images", sample[0]) 87 | frame = cv2.imread(os.path.join(folder_path,"%06d.png"%(frame_idx))) 88 | wav = folder_path.replace("images","wavs") + ".wav" 89 | else: 90 | folder_path = os.path.join("/gpu-data3/filby/EAVTTS/TCDTIMIT_preprocessed/images", sample[0]) 91 | frame = cv2.imread(os.path.join(folder_path,"%06d.png"%(frame_idx))) 92 | wav = folder_path.replace("images","wavs") + ".wav" 93 | 94 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 95 | kpt = preprocessed_landmarks[frame_idx] 96 | tform = self.crop_face(frame,kpt,scale) 97 | cropped_image = warp(frame, tform.inverse, output_shape=(self.image_size, self.image_size)) 98 | 99 | cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T 100 | 101 | cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2 - 1 102 | 103 | images_list.append(cropped_image.transpose(2,0,1)) 104 | kpt_list.append(cropped_kpt) 105 | 106 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32) #K,224,224,3 107 | kpt_array = torch.from_numpy(np.array(kpt_list)).type(dtype = torch.float32) #K,224,224,3 108 | 109 | # text = open(folder_path+".txt").readlines()[0].replace("Text:","").strip() 110 | text = sample[1] # open(folder_path+".txt").readlines()[0].replace("Text:","").strip() 111 | 112 
| data_dict = { 113 | 'image': images_array, 114 | 'landmark': kpt_array, 115 | 'vid_name': sample[0], 116 | 'wav_path': wav, # this is only used for evaluation - you can remove this key from the dictionary if you don't need it 117 | 'text': text, # this is only used for evaluation - you can remove this key from the dictionary if you don't need it 118 | } 119 | 120 | return data_dict 121 | 122 | 123 | def get_datasets_LRS3(config=None): 124 | if not os.path.exists('data/LRS3_lists.pkl'): 125 | print('Creating train, validation, and test lists for LRS3... (This only happens once)') 126 | 127 | from .data_utils import create_LRS3_lists 128 | create_LRS3_lists(config.LRS3_path) 129 | 130 | 131 | lists = pickle.load(open("data/LRS3_lists.pkl", "rb")) 132 | train_list = lists[0] 133 | val_list = lists[1] 134 | test_list = lists[2] 135 | landmarks_path = config.LRS3_landmarks_path 136 | return SpectreDataset(train_list, landmarks_path, cfg=config), SpectreDataset(val_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path, 137 | cfg=config, 138 | test=True) 139 | -------------------------------------------------------------------------------- /external/spectre/datasets/extra_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .datasets import SpectreDataset 3 | 4 | 5 | def get_datasets_MEAD(config=None): 6 | import pandas as pd 7 | questionnaire_list = pd.read_csv("../utils/MEAD_test_set_final.csv") 8 | test_list = [(x[0],x[1]) for x in zip(questionnaire_list.name,questionnaire_list.text)] 9 | landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/MEAD_images_25fps" 10 | 11 | return None, None, SpectreDataset(test_list, landmarks_path, cfg=config, test=True) 12 | 13 | 14 | def get_datasets_TCDTIMIT(config=None): 15 | tcd_root = "/gpu-data3/filby/EAVTTS" 16 | 17 | landmarks_path = "../Visual_Speech_Recognition_for_Multiple_Languages/landmarks/TCDTIMIT_images_25fps" 18 | 19 | root = f"{tcd_root}/TCDTIMIT_preprocessed/TCDSpkrIndepTrainSet.scp" 20 | files = open(root).readlines() 21 | train_list = [] 22 | for file in files: 23 | f = file.strip().split("/") 24 | new_name = f"{f[0]}_{f[-1]}" 25 | 26 | ff = "/".join([f[0],f[1],f[2]]) 27 | 28 | text = open(os.path.join(f"{tcd_root}/TCDTIMITprocessing/downloadTCDTIMIT/volunteers",ff,f[-1].upper().replace(".MP4",".txt"))).readlines() 29 | 30 | text = " ".join([x.split()[2].strip() for x in text]) 31 | 32 | train_list.append((new_name.split(".")[0],text)) 33 | 34 | 35 | root = f"{tcd_root}/TCDTIMIT_preprocessed/TCDSpkrIndepTestSet.scp" 36 | files = open(root).readlines() 37 | test_list = [] 38 | for file in files: 39 | f = file.strip().split("/") 40 | new_name = f"{f[0]}_{f[-1]}" 41 | 42 | ff = "/".join([f[0],f[1],f[2]]) 43 | 44 | text = open(os.path.join(f"{tcd_root}/TCDTIMITprocessing/downloadTCDTIMIT/volunteers",ff,f[-1].upper().replace(".MP4",".txt"))).readlines() 45 | 46 | text = " ".join([x.split()[2].strip() for x in text]) 47 | 48 | test_list.append((new_name.split(".")[0],text.upper())) 49 | 50 | 51 | return SpectreDataset(train_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path, cfg=config), SpectreDataset(test_list, landmarks_path, cfg=config, test=True) -------------------------------------------------------------------------------- /external/spectre/demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import argparse 5 | 
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | import os, sys 7 | import torch 8 | import numpy as np 9 | import cv2 10 | from skimage.transform import estimate_transform, warp, resize, rescale 11 | import scipy.io 12 | import collections 13 | from tqdm import tqdm 14 | from datasets.data_utils import landmarks_interpolate 15 | from src.spectre import SPECTRE 16 | from config import cfg as spectre_cfg 17 | from src.utils.util import tensor2video 18 | import torchvision 19 | 20 | def extract_frames(video_path, detect_landmarks=True): 21 | videofolder = os.path.splitext(video_path)[0] 22 | os.makedirs(videofolder, exist_ok=True) 23 | vidcap = cv2.VideoCapture(video_path) 24 | 25 | if detect_landmarks is True: 26 | from external.Visual_Speech_Recognition_for_Multiple_Languages.tracker.face_tracker import FaceTracker 27 | from external.Visual_Speech_Recognition_for_Multiple_Languages.tracker.utils import get_landmarks 28 | face_tracker = FaceTracker() 29 | 30 | imagepath_list = [] 31 | count = 0 32 | 33 | face_info = collections.defaultdict(list) 34 | 35 | fps = vidcap.get(cv2.CAP_PROP_FPS) 36 | 37 | with tqdm(total=int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar: 38 | while True: 39 | success, image = vidcap.read() 40 | if not success: 41 | break 42 | 43 | if detect_landmarks is True: 44 | detected_faces = face_tracker.face_detector(image, rgb=False) 45 | # -- face alignment 46 | landmarks, scores = face_tracker.landmark_detector(image, detected_faces, rgb=False) 47 | face_info['bbox'].append(detected_faces) 48 | face_info['landmarks'].append(landmarks) 49 | face_info['landmarks_scores'].append(scores) 50 | 51 | imagepath = os.path.join(videofolder, f'{count:06d}.jpg') 52 | cv2.imwrite(imagepath, image) # save frame as JPEG file 53 | count += 1 54 | imagepath_list.append(imagepath) 55 | pbar.update(1) 56 | pbar.set_description("Preprocessing frame %d" % count) 57 | 58 | landmarks = get_landmarks(face_info) 59 | print('video frames are stored in {}'.format(videofolder)) 60 | return imagepath_list, landmarks, videofolder, fps 61 | 62 | 63 | 64 | def crop_face(frame, landmarks, scale=1.0): 65 | image_size = 224 66 | left = np.min(landmarks[:, 0]) 67 | right = np.max(landmarks[:, 0]) 68 | top = np.min(landmarks[:, 1]) 69 | bottom = np.max(landmarks[:, 1]) 70 | 71 | h, w, _ = frame.shape 72 | old_size = (right - left + bottom - top) / 2 73 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) 74 | 75 | size = int(old_size * scale) 76 | 77 | src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], [center[0] - size / 2, center[1] + size / 2], 78 | [center[0] + size / 2, center[1] - size / 2]]) 79 | DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]]) 80 | tform = estimate_transform('similarity', src_pts, DST_PTS) 81 | 82 | return tform 83 | 84 | 85 | 86 | def main(args): 87 | args.crop_face = True 88 | spectre_cfg.pretrained_modelpath = "pretrained/spectre_model.tar" 89 | spectre_cfg.model.use_tex = False 90 | 91 | spectre = SPECTRE(spectre_cfg, args.device) 92 | spectre.eval() 93 | 94 | image_paths, landmarks, videofolder, fps = extract_frames(args.input, detect_landmarks=args.crop_face) 95 | if args.crop_face: 96 | landmarks = landmarks_interpolate(landmarks) 97 | if landmarks is None: 98 | print('No faces detected in input {}'.format(args.input)) 99 | 100 | 101 | original_video_length = len(image_paths) 102 | """ SPECTRE uses a temporal convolution of size 5. 
103 | Thus, in order to predict the parameters for a contiguous video with need to 104 | process the video in chunks of overlap 2, dropping values which were computed from the 105 | temporal kernel which uses pad 'same'. For the start and end of the video we 106 | pad using the first and last frame of the video. 107 | e.g., consider a video of size 48 frames and we want to predict it in chunks of 20 frames 108 | (due to memory limitations). We first pad the video two frames at the start and end using 109 | the first and last frames correspondingly, making the video 52 frames length. 110 | 111 | Then we process independently the following chunks: 112 | [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] 113 | [16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35] 114 | [32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51]] 115 | 116 | In the first chunk, after computing the 3DMM params we drop 0,1 and 18,19, since they were computed 117 | from the temporal kernel with padding (we followed the same procedure in training and computed loss 118 | only from valid outputs of the temporal kernel) In the second chunk, we drop 16,17 and 34,35, and in 119 | the last chunk we drop 32,33 and 50,51. As a result we get: 120 | [2..17], [18..33], [34..49] (end included) which correspond to all frames of the original video 121 | (removing the initial padding). 122 | """ 123 | 124 | # pad 125 | image_paths.insert(0,image_paths[0]) 126 | image_paths.insert(0,image_paths[0]) 127 | image_paths.append(image_paths[-1]) 128 | image_paths.append(image_paths[-1]) 129 | 130 | landmarks.insert(0,landmarks[0]) 131 | landmarks.insert(0,landmarks[0]) 132 | landmarks.append(landmarks[-1]) 133 | landmarks.append(landmarks[-1]) 134 | 135 | landmarks = np.array(landmarks) 136 | 137 | L = 50 # chunk size 138 | 139 | # create lists of overlapping indices 140 | indices = list(range(len(image_paths))) 141 | overlapping_indices = [indices[i: i + L] for i in range(0, len(indices), L-4)] 142 | 143 | if len(overlapping_indices[-1]) < 5: 144 | # if the last chunk has less than 5 frames, pad it with the semilast frame 145 | overlapping_indices[-2] = overlapping_indices[-2] + overlapping_indices[-1] 146 | overlapping_indices[-2] = np.unique(overlapping_indices[-2]).tolist() 147 | overlapping_indices = overlapping_indices[:-1] 148 | 149 | overlapping_indices = np.array(overlapping_indices) 150 | 151 | image_paths = np.array(image_paths) # do this to index with multiple indices 152 | all_shape_images = [] 153 | all_images = [] 154 | 155 | with torch.no_grad(): 156 | for chunk_id in range(len(overlapping_indices)): 157 | print('Processing frames {} to {}'.format(overlapping_indices[chunk_id][0], overlapping_indices[chunk_id][-1])) 158 | image_paths_chunk = image_paths[overlapping_indices[chunk_id]] 159 | 160 | landmarks_chunk = landmarks[overlapping_indices[chunk_id]] if args.crop_face else None 161 | 162 | images_list = [] 163 | 164 | """ load each image and crop it around the face if necessary """ 165 | for j in range(len(image_paths_chunk)): 166 | frame = cv2.imread(image_paths_chunk[j]) 167 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 168 | kpt = landmarks_chunk[j] 169 | 170 | tform = crop_face(frame,kpt,scale=1.6) 171 | cropped_image = warp(frame, tform.inverse, output_shape=(224, 224)) 172 | 173 | images_list.append(cropped_image.transpose(2,0,1)) 174 | 175 | images_array = torch.from_numpy(np.array(images_list)).type(dtype = torch.float32).to(args.device) #K,224,224,3 176 | 177 | codedict, initial_deca_exp, 
initial_deca_jaw = spectre.encode(images_array) 178 | codedict['exp'] = codedict['exp'] + initial_deca_exp 179 | codedict['pose'][..., 3:] = codedict['pose'][..., 3:] + initial_deca_jaw 180 | 181 | for key in codedict.keys(): 182 | """ filter out invalid indices - see explanation at the top of the function """ 183 | 184 | if chunk_id == 0 and chunk_id == len(overlapping_indices) - 1: 185 | pass 186 | elif chunk_id == 0: 187 | codedict[key] = codedict[key][:-2] 188 | elif chunk_id == len(overlapping_indices) - 1: 189 | codedict[key] = codedict[key][2:] 190 | else: 191 | codedict[key] = codedict[key][2:-2] 192 | 193 | opdict, visdict = spectre.decode(codedict, rendering=True, vis_lmk=False, return_vis=True) 194 | all_shape_images.append(visdict['shape_images'].detach().cpu()) 195 | all_images.append(codedict['images'].detach().cpu()) 196 | 197 | vid_shape = tensor2video(torch.cat(all_shape_images, dim=0))[2:-2] # remove padding 198 | vid_orig = tensor2video(torch.cat(all_images, dim=0))[2:-2] # remove padding 199 | grid_vid = np.concatenate((vid_shape, vid_orig), axis=2) 200 | 201 | assert original_video_length == len(vid_shape) 202 | 203 | if args.audio: 204 | import librosa 205 | wav, sr = librosa.load(args.input) 206 | wav = torch.FloatTensor(wav) 207 | if len(wav.shape) == 1: 208 | wav = wav.unsqueeze(0) 209 | 210 | torchvision.io.write_video(videofolder+"_shape.mp4", vid_shape, fps=fps, audio_codec='aac', audio_array=wav, audio_fps=sr) 211 | torchvision.io.write_video(videofolder+"_grid.mp4", grid_vid, fps=fps, 212 | audio_codec='aac', audio_array=wav, audio_fps=sr) 213 | 214 | else: 215 | torchvision.io.write_video(videofolder+"_shape.mp4", vid_shape, fps=fps) 216 | torchvision.io.write_video(videofolder+"_grid.mp4", grid_vid, fps=fps) 217 | 218 | 219 | if __name__ == '__main__': 220 | parser = argparse.ArgumentParser(description='DECA: Detailed Expression Capture and Animation') 221 | 222 | parser.add_argument('-i', '--input', default='examples', type=str, 223 | help='path to the test data, can be image folder, image path, image list, video') 224 | # parser.add_argument('-o', '--outpath', default='examples/results', type=str, 225 | # help='path to the output directory, where results(obj, txt files) will be stored.') 226 | parser.add_argument('--device', default='cuda', type=str, 227 | help='set device, cpu for using cpu') 228 | parser.add_argument('--audio', action='store_true', 229 | help='extract audio from the original video and add it to the output video') 230 | 231 | main(parser.parse_args()) -------------------------------------------------------------------------------- /external/spectre/get_training_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # file adapted from MICA https://raw.githubusercontent.com/Zielon/MICA 3 | # 4 | echo -e "\nDownloading deca_model..." 
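# Descriptive note on the download blocks below: each pairs two wget calls in the usual Google Drive
# confirm-token pattern for large files. The inner wget saves session cookies and scrapes the
# confirmation token with sed; the outer wget then fetches the file with that token, and the temporary
# cookie file is removed afterwards. The script also appears to assume that the ./data/ target
# directory already exists before the first download writes into it.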
5 | # 6 | FILEID=1rp8kdyLPvErw2dTmqtjISRVvQLj6Yzje 7 | FILENAME=./data/deca_model.tar 8 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt 9 | 10 | 11 | echo "To download the Emotion Recognition from EMOCA which is used from SPECTRE for expression loss, please register at:", 12 | echo -e '\e]8;;https://emoca.is.tue.mpg.de\ahttps://emoca.is.tue.mpg.de\e]8;;\a' 13 | while true; do 14 | read -p "I have registered and agreed to the license terms at https://emoca.is.tue.mpg.de? (y/n)" yn 15 | case $yn in 16 | [Yy]* ) break;; 17 | [Nn]* ) exit;; 18 | * ) echo "Please answer yes or no.";; 19 | esac 20 | done 21 | 22 | wget https://download.is.tue.mpg.de/emoca/assets/EmotionRecognition/image_based_networks/ResNet50.zip -O ResNet50.zip 23 | unzip ResNet50.zip -d data/ 24 | rm ResNet50.zip 25 | 26 | echo -e "\nDownloading lipreading pretrained model..." 27 | 28 | FILEID=1yHd4QwC7K_9Ro2OM_hC7pKUT2URPvm_f 29 | FILENAME=LRS3_V_WER32.3.zip 30 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt 31 | unzip $FILENAME -d data/ 32 | rm LRS3_V_WER32.3.zip 33 | 34 | echo -e "\nDownloading landmarks for LRS3 dataset ..." 35 | 36 | gdown --id 1QRdOgeHvmKK8t4hsceFVf_BSpidQfUyW 37 | unzip LRS3_landmarks.zip -d data/ 38 | rm LRS3_landmarks.zip 39 | 40 | 41 | 42 | echo -e "\nInstallation has finished!" 
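# Usage sketch (an assumption, not stated in the script itself): run it from the external/spectre
# directory so that the relative data/ paths above resolve, e.g.
#   mkdir -p data && bash get_training_data.sh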
43 | -------------------------------------------------------------------------------- /external/spectre/main.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import yaml 4 | import torch.backends.cudnn as cudnn 5 | import torch 6 | import shutil 7 | 8 | 9 | def main(cfg): 10 | # creat folders 11 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.log_dir), exist_ok=True) 12 | 13 | if cfg.test_mode is False: 14 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.vis_dir), exist_ok=True) 15 | os.makedirs(os.path.join(cfg.output_dir, cfg.train.val_vis_dir), exist_ok=True) 16 | with open(os.path.join(cfg.output_dir, 'full_config.yaml'), 'w') as f: 17 | yaml.dump(cfg, f, default_flow_style=False) 18 | 19 | # cudnn related setting 20 | cudnn.benchmark = True 21 | torch.backends.cudnn.deterministic = False 22 | torch.backends.cudnn.enabled = True 23 | 24 | # start training 25 | from src.trainer_spectre import Trainer 26 | from src.spectre import SPECTRE 27 | spectre = SPECTRE(cfg) 28 | 29 | trainer = Trainer(model=spectre, config=cfg) 30 | 31 | if cfg.test_mode: 32 | trainer.prepare_data() 33 | trainer.evaluate(trainer.test_datasets) 34 | else: 35 | trainer.fit() 36 | 37 | if __name__ == '__main__': 38 | from config import parse_args 39 | cfg = parse_args() 40 | cfg.exp_name = cfg.output_dir 41 | 42 | main(cfg) 43 | -------------------------------------------------------------------------------- /external/spectre/quick_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # file adapted from MICA https://github.com/Zielon/MICA/ 3 | 4 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 5 | 6 | # username and password input 7 | echo -e "\nIf you do not have an account you can register at https://flame.is.tue.mpg.de/ following the installation instruction." 8 | read -p "Username (FLAME):" username 9 | read -p "Password (FLAME):" password 10 | username=$(urle $username) 11 | password=$(urle $password) 12 | 13 | echo -e "\nDownloading FLAME..." 14 | mkdir -p data/FLAME2020/ 15 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=flame&sfile=FLAME2020.zip&resume=1' -O './FLAME2020.zip' --no-check-certificate --continue 16 | unzip FLAME2020.zip -d data/FLAME2020/ 17 | rm -rf FLAME2020.zip 18 | 19 | echo -e "\nDownload pretrained SPECTRE model..." 
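# Note: gdown (used below) is a separate pip-installable helper for Google Drive downloads;
# if it is not already available, install it first (e.g. pip install gdown) -- this script does not install it itself.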
20 | gdown --id 1vmWX6QmXGPnXTXWFgj67oHzOoOmxBh6B 21 | mkdir -p pretrained/ 22 | mv spectre_model.tar pretrained/ 23 | 24 | 25 | -------------------------------------------------------------------------------- /external/spectre/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import random 5 | import math 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torchvision 10 | from tqdm import tqdm 11 | from .config import cfg as spectre_cfg 12 | from .src.spectre import SPECTRE 13 | 14 | 15 | class Render(): 16 | def __init__(self, device='cuda:0') -> None: 17 | # model 18 | self.device = device 19 | spectre_cfg.pretrained_modelpath = "external/spectre/pretrained/HDTF_pretrained/00032000.tar" 20 | spectre_cfg.model.use_tex = False 21 | self.spectre = SPECTRE(spectre_cfg, device=self.device) 22 | self.spectre.eval() 23 | 24 | def forward(self, exp_out, exp, data_info): 25 | 'input: expression coefficients' 26 | 'output: mesh' 27 | n = self.cfg['trainer']['visual_images'] 28 | if self.cfg['datasets']['dataset']=='mead': 29 | # template 30 | codedict = {} 31 | codedict['pose'] = torch.zeros((n*3, 6), dtype=torch.float).to(self.device) 32 | codedict['exp'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device) 33 | codedict['shape'] = torch.zeros((n*3, 100), dtype=torch.float).to(self.device) 34 | codedict['tex'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device) 35 | codedict['cam'] = torch.zeros((n*3, 3), dtype=torch.float).to(self.device) 36 | self.codedict = codedict 37 | # true coefficients 38 | coefficient_path = os.path.join(data_info, 'crop_head_info.npy') 39 | coefficient_info = np.load(coefficient_path, allow_pickle=True).item()['face3d_encode'] 40 | coefficients = get_coefficients(coefficient_info) 41 | for key in coefficients: 42 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device) 43 | start_vis = random.randint(0,exp.shape[1]-1-n) # starting frame 44 | 45 | self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # generated parameters on the mean face 46 | self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground_truth parameters on the mean face 47 | self.codedict['exp'][2*n:3*n] = coefficients['exp'][start_vis: start_vis+n,:] # ground_truth on the original face 48 | 49 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose 50 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:] 51 | 52 | self.codedict['cam'][0:n] = coefficients['cam'][start_vis: start_vis+n, :] # take the cam of these n frames 53 | self.codedict['cam'][n:2*n] = coefficients['cam'][start_vis: start_vis+n, :] 54 | self.codedict['cam'][2*n:3*n] = coefficients['cam'][start_vis: start_vis+n, :] 55 | 56 | self.codedict['pose'][2*n:3*n] = coefficients['pose'][start_vis: start_vis+n, :] # take the pose of these n frames 57 | self.codedict['shape'][2*n:3*n] = coefficients['shape'][start_vis: start_vis+n, :] # take the shape of these n frames 58 | 59 | elif self.cfg['datasets']['dataset']=='mote': 60 | # template 61 | codedict = {} 62 | codedict['pose'] = torch.zeros((n*2, 6), dtype=torch.float).to(self.device) 63 | codedict['exp'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device) 64 | codedict['shape'] = torch.zeros((n*2, 100), dtype=torch.float).to(self.device) 65 | codedict['tex'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device) 66 | codedict['cam'] = torch.zeros((n*2, 3), dtype=torch.float).to(self.device) 67 | self.codedict = codedict 68 | # true coefficients 69 | # coefficient_path =
os.path.join(self.cfg['datasets']['data_root'], data_info[0][0], data_info[1][0], 'train1_all.npz') 70 | # coefficient_info = np.load(coefficient_path, allow_pickle=True)['face'][-1*self.cfg['datasets']['eval_frames']:, :] 71 | # coefficients = get_coefficients(coefficient_info) 72 | # for key in coefficients: 73 | # coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device) 74 | start_vis = random.randint(0,exp.shape[1]-1-n) # starting frame 75 | 76 | self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # generated parameters on the mean face 77 | self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground_truth parameters on the mean face 78 | 79 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose 80 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:] 81 | 82 | cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1) # cam 83 | self.codedict['cam'][0:n] = cam 84 | self.codedict['cam'][n:2*n] = cam 85 | 86 | opdict = self.spectre.decode(self.codedict, rendering=True, vis_lmk=False, return_vis=False) 87 | # rendered_images = torchvision.utils.make_grid(opdict['rendered_images'].detach().cpu(), nrow=n) 88 | return opdict['rendered_images'] 89 | 90 | def infer(self, exp, exp_gt=None, render_batch=100): 91 | 'input: expression coefficients' 92 | 'output: mesh' 93 | n = exp.shape[1] 94 | coefficients = {} 95 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(exp.device) 96 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(exp.device) 97 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(exp.device) 98 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(exp.device) 99 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(exp.device) 100 | coefficients_pred = copy.deepcopy(coefficients) 101 | # cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1).to(exp.device) # cam 102 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(exp.device) # cam 103 | for key in coefficients: 104 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(exp.device) 105 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(exp.device) 106 | if key == 'exp': 107 | if exp_gt is not None: 108 | coefficients[key] = exp_gt[0][:, :-3] 109 | coefficients_pred[key] = exp[0][:, :-3] 110 | elif key == 'pose': 111 | if exp_gt is not None: 112 | coefficients[key][:, -3:] = exp_gt[0][:, -3:] 113 | coefficients_pred[key][:, -3:] = exp[0][:, -3:] 114 | elif key == 'cam': 115 | coefficients[key] = cam[:, :] 116 | coefficients_pred[key] = cam[:, :] 117 | n_batch = int(math.ceil(n/render_batch)) 118 | rendered_images, rendered_images_pred = [], [] 119 | for i in range(n_batch): 120 | coefficients_render, coefficients_pred_render = {}, {} 121 | for k in coefficients: 122 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 123 | coefficients_render[k] = coefficients[k][start_f: end_f] 124 | coefficients_pred_render[k] = coefficients_pred[k][start_f: end_f] 125 | 126 | if exp_gt is not None: 127 | opdict = self.spectre.decode(coefficients_render, rendering=True, vis_lmk=False, return_vis=False) 128 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 129 | opdict_pred = self.spectre.decode(coefficients_pred_render, rendering=True, vis_lmk=False, return_vis=False) 130 | rendered_images_pred.append(opdict_pred['rendered_images'].detach().cpu()) 131 |
if exp_gt is not None: 132 | rendered_images_cat = torch.cat(rendered_images, dim=0) 133 | else: 134 | rendered_images_cat = None 135 | rendered_images_pred_cat = torch.cat(rendered_images_pred, dim=0) 136 | 137 | return rendered_images_cat, rendered_images_pred_cat 138 | # opdict = self.spectre.decode(coefficients, rendering=True, vis_lmk=False, return_vis=False) 139 | # opdict_pred = self.spectre.decode(coefficients_pred, rendering=True, vis_lmk=False, return_vis=False) 140 | # return opdict['rendered_images'], opdict_pred['rendered_images'] 141 | 142 | def exp2mesh(self, coefficients_info, pose0=True, render_batch=100): 143 | n = coefficients_info.shape[0] 144 | if coefficients_info.shape[-1] == 53: 145 | coefficients = {} 146 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(coefficients_info.device) 147 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(coefficients_info.device) 148 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(coefficients_info.device) 149 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(coefficients_info.device) 150 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(coefficients_info.device) 151 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coefficients_info.device) # cam 152 | for key in coefficients: 153 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(coefficients_info.device) 154 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(coefficients_info.device) 155 | if key == 'exp': 156 | coefficients[key] = coefficients_info[:, 3:] 157 | elif key == 'pose': 158 | coefficients[key][:, -3:] = coefficients_info[:, :3] 159 | elif key == 'cam': 160 | coefficients[key] = cam[:, :] 161 | elif coefficients_info.shape[-1] == 209 or coefficients_info.shape[-1] == 213 or coefficients_info.shape[-1] == 236: 162 | coefficients = get_coefficients(coefficients_info) 163 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coefficients_info.device) # cam 164 | for key in coefficients: 165 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(coefficients_info.device) 166 | if pose0: 167 | if key == 'pose': 168 | coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3]) 169 | elif key == 'shape' or key == 'tex': 170 | coefficients[key] = torch.zeros_like(coefficients[key]) 171 | elif key == 'cam': 172 | coefficients[key] = cam 173 | 174 | n_batch = int(math.ceil(n/render_batch)) 175 | rendered_images = [] 176 | vertices = [] 177 | for i in tqdm(range(n_batch)): 178 | coefficients_batch = {} 179 | for k in coefficients: 180 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 181 | coefficients_batch[k] = coefficients[k][start_f: end_f] 182 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False) 183 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 184 | vertices.append(opdict['verts'].detach().cpu()) 185 | rendered_images_cat = torch.cat(rendered_images, dim=0) 186 | vertices_cat = torch.cat(vertices, dim=0) 187 | return rendered_images_cat, vertices_cat 188 | 189 | def coff2mesh(self, coeff, pose0=True, render_batch=100): 190 | n = coeff.shape[0] 191 | 192 | assert coeff.shape[-1] == 209 or coeff.shape[-1] == 213 or coeff.shape[-1] == 236 193 | coefficients = get_coefficients(coeff) 194 | cam = torch.tensor([8.740263, -0.00034628902, 
0.020510273]).unsqueeze(0).repeat(n, 1).to(coeff.device) # cam 195 | coefficients['cam'] = cam 196 | if pose0: 197 | coefficients['pose'][:, :3] = torch.zeros_like(coefficients['pose'][:, :3]) 198 | 199 | n_batch = int(math.ceil(n/render_batch)) 200 | rendered_images = [] 201 | vertices = [] 202 | for i in tqdm(range(n_batch)): 203 | coefficients_batch = {} 204 | for k in coefficients: 205 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 206 | coefficients_batch[k] = coefficients[k][start_f: end_f] 207 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False) 208 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 209 | vertices.append(opdict['verts'].detach().cpu()) 210 | rendered_images_cat = torch.cat(rendered_images, dim=0) 211 | vertices_cat = torch.cat(vertices, dim=0) 212 | return rendered_images_cat, vertices_cat 213 | 214 | def coff2mesh_rawcam(self, coeff, pose0=True, render_batch=100): 215 | n = coeff.shape[0] 216 | 217 | if coeff.shape[-1] == 209 or coeff.shape[-1] == 213 or coeff.shape[-1] == 236: 218 | coefficients = get_coefficients(coeff) 219 | if pose0: 220 | for key in coefficients: 221 | if key == 'pose': 222 | coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3]) 223 | # cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(coeff.device) # cam 224 | # coefficients['cam'] = cam 225 | # for key in coefficients: 226 | # coefficients[key] = torch.FloatTensor(coefficients[key]).to(coefficients_info.device) 227 | # if pose0: 228 | # if key == 'pose': 229 | # coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3]) 230 | # elif key == 'shape' or key == 'tex': 231 | # coefficients[key] = torch.zeros_like(coefficients[key]) 232 | # elif key == 'cam': 233 | # coefficients[key] = cam 234 | 235 | n_batch = int(math.ceil(n/render_batch)) 236 | rendered_images = [] 237 | vertices = [] 238 | for i in tqdm(range(n_batch)): 239 | coefficients_batch = {} 240 | for k in coefficients: 241 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 242 | coefficients_batch[k] = coefficients[k][start_f: end_f] 243 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False) 244 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 245 | vertices.append(opdict['verts'].detach().cpu()) 246 | rendered_images_cat = torch.cat(rendered_images, dim=0) 247 | vertices_cat = torch.cat(vertices, dim=0) 248 | return rendered_images_cat, vertices_cat 249 | 250 | 251 | def get_coefficients(coefficient_info): 252 | coefficient_dict = {} 253 | coefficient_dict['pose'] = coefficient_info[:, :6] 254 | coefficient_dict['exp'] = coefficient_info[:, 6:56] 255 | coefficient_dict['shape'] = coefficient_info[:, 56:156] 256 | coefficient_dict['tex'] = coefficient_info[:, 156:206] 257 | coefficient_dict['cam'] = coefficient_info[:, 206:209] 258 | # coefficient_dict['light'] = coefficient_info[:, 209:236] 259 | return coefficient_dict 260 | -------------------------------------------------------------------------------- /external/spectre/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit_image==0.19.3 2 | kornia==0.6.6 3 | chumpy==0.70 4 | librosa==0.9.2 5 | av==9.2.0 6 | loguru==0.6.0 7 | tensorboard==2.9.1 8 | pytorch_lightning==1.5 9 | opencv-python==4.6.0.66 10 | phonemizer==3.2.1 11 | jiwer==2.3.0 
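# Install sketch (assuming a fresh Python environment with this file as the working-directory requirements file): pip install -r requirements.txt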
-------------------------------------------------------------------------------- /external/spectre/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/src/__init__.py -------------------------------------------------------------------------------- /external/spectre/src/models/FLAME.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import pickle 20 | import torch.nn.functional as F 21 | 22 | from .lbs import lbs, batch_rodrigues, vertices2landmarks, rot_mat_to_euler 23 | 24 | def to_tensor(array, dtype=torch.float32): 25 | if 'torch.tensor' not in str(type(array)): 26 | return torch.tensor(array, dtype=dtype) 27 | def to_np(array, dtype=np.float32): 28 | if 'scipy.sparse' in str(type(array)): 29 | array = array.todense() 30 | return np.array(array, dtype=dtype) 31 | 32 | class Struct(object): 33 | def __init__(self, **kwargs): 34 | for key, val in kwargs.items(): 35 | setattr(self, key, val) 36 | 37 | class FLAME(nn.Module): 38 | """ 39 | borrowed from https://github.com/soubhiksanyal/FLAME_PyTorch/blob/master/FLAME.py 40 | Given flame parameters this class generates a differentiable FLAME function 41 | which outputs the a mesh and 2D/3D facial landmarks 42 | """ 43 | def __init__(self, config): 44 | super(FLAME, self).__init__() 45 | # print("creating the FLAME Decoder") 46 | with open(config.flame_model_path, 'rb') as f: 47 | ss = pickle.load(f, encoding='latin1') 48 | flame_model = Struct(**ss) 49 | 50 | self.dtype = torch.float32 51 | self.register_buffer('faces_tensor', to_tensor(to_np(flame_model.f, dtype=np.int64), dtype=torch.long)) 52 | # The vertices of the template model 53 | self.register_buffer('v_template', to_tensor(to_np(flame_model.v_template), dtype=self.dtype)) 54 | # The shape components and expression 55 | shapedirs = to_tensor(to_np(flame_model.shapedirs), dtype=self.dtype) 56 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2) 57 | self.register_buffer('shapedirs', shapedirs) 58 | # The pose components 59 | num_pose_basis = flame_model.posedirs.shape[-1] 60 | posedirs = np.reshape(flame_model.posedirs, [-1, num_pose_basis]).T 61 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 62 | # 63 | self.register_buffer('J_regressor', to_tensor(to_np(flame_model.J_regressor), dtype=self.dtype)) 64 | parents = to_tensor(to_np(flame_model.kintree_table[0])).long(); parents[0] = -1 65 | self.register_buffer('parents', parents) 66 | self.register_buffer('lbs_weights', to_tensor(to_np(flame_model.weights), 
dtype=self.dtype)) 67 | 68 | # Fixing Eyeball and neck rotation 69 | default_eyball_pose = torch.zeros([1, 6], dtype=self.dtype, requires_grad=False) 70 | self.register_parameter('eye_pose', nn.Parameter(default_eyball_pose, 71 | requires_grad=False)) 72 | default_neck_pose = torch.zeros([1, 3], dtype=self.dtype, requires_grad=False) 73 | self.register_parameter('neck_pose', nn.Parameter(default_neck_pose, 74 | requires_grad=False)) 75 | 76 | # Static and Dynamic Landmark embeddings for FLAME 77 | lmk_embeddings = np.load(config.flame_lmk_embedding_path, allow_pickle=True, encoding='latin1') 78 | lmk_embeddings = lmk_embeddings[()] 79 | self.register_buffer('lmk_faces_idx', torch.from_numpy(lmk_embeddings['static_lmk_faces_idx']).long()) 80 | self.register_buffer('lmk_bary_coords', torch.from_numpy(lmk_embeddings['static_lmk_bary_coords']).to(self.dtype)) 81 | self.register_buffer('dynamic_lmk_faces_idx', lmk_embeddings['dynamic_lmk_faces_idx'].long()) 82 | self.register_buffer('dynamic_lmk_bary_coords', lmk_embeddings['dynamic_lmk_bary_coords'].to(self.dtype)) 83 | self.register_buffer('full_lmk_faces_idx', torch.from_numpy(lmk_embeddings['full_lmk_faces_idx']).long()) 84 | self.register_buffer('full_lmk_bary_coords', torch.from_numpy(lmk_embeddings['full_lmk_bary_coords']).to(self.dtype)) 85 | 86 | neck_kin_chain = []; NECK_IDX=1 87 | curr_idx = torch.tensor(NECK_IDX, dtype=torch.long) 88 | while curr_idx != -1: 89 | neck_kin_chain.append(curr_idx) 90 | curr_idx = self.parents[curr_idx] 91 | self.register_buffer('neck_kin_chain', torch.stack(neck_kin_chain)) 92 | 93 | def _find_dynamic_lmk_idx_and_bcoords(self, pose, dynamic_lmk_faces_idx, 94 | dynamic_lmk_b_coords, 95 | neck_kin_chain, dtype=torch.float32): 96 | """ 97 | Selects the face contour depending on the reletive position of the head 98 | Input: 99 | vertices: N X num_of_vertices X 3 100 | pose: N X full pose 101 | dynamic_lmk_faces_idx: The list of contour face indexes 102 | dynamic_lmk_b_coords: The list of contour barycentric weights 103 | neck_kin_chain: The tree to consider for the relative rotation 104 | dtype: Data type 105 | return: 106 | The contour face indexes and the corresponding barycentric weights 107 | """ 108 | 109 | batch_size = pose.shape[0] 110 | 111 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 112 | neck_kin_chain) 113 | rot_mats = batch_rodrigues( 114 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 115 | 116 | rel_rot_mat = torch.eye(3, device=pose.device, 117 | dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1) 118 | for idx in range(len(neck_kin_chain)): 119 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 120 | 121 | y_rot_angle = torch.round( 122 | torch.clamp(rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 123 | max=39)).to(dtype=torch.long) 124 | 125 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 126 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 127 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 128 | y_rot_angle = (neg_mask * neg_vals + 129 | (1 - neg_mask) * y_rot_angle) 130 | 131 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 132 | 0, y_rot_angle) 133 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 134 | 0, y_rot_angle) 135 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 136 | 137 | def _vertices2landmarks(self, vertices, faces, lmk_faces_idx, lmk_bary_coords): 138 | """ 139 | Calculates landmarks by barycentric interpolation 140 | Input: 141 | vertices: torch.tensor NxVx3, dtype = torch.float32 142 | The 
tensor of input vertices 143 | faces: torch.tensor (N*F)x3, dtype = torch.long 144 | The faces of the mesh 145 | lmk_faces_idx: torch.tensor N X L, dtype = torch.long 146 | The tensor with the indices of the faces used to calculate the 147 | landmarks. 148 | lmk_bary_coords: torch.tensor N X L X 3, dtype = torch.float32 149 | The tensor of barycentric coordinates that are used to interpolate 150 | the landmarks 151 | 152 | Returns: 153 | landmarks: torch.tensor NxLx3, dtype = torch.float32 154 | The coordinates of the landmarks for each mesh in the batch 155 | """ 156 | # Extract the indices of the vertices for each face 157 | # NxLx3 158 | batch_size, num_verts = vertices.shape[:dd2] 159 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 160 | 1, -1, 3).view(batch_size, lmk_faces_idx.shape[1], -1) 161 | 162 | lmk_faces += torch.arange(batch_size, dtype=torch.long).view(-1, 1, 1).to( 163 | device=vertices.device) * num_verts 164 | 165 | lmk_vertices = vertices.view(-1, 3)[lmk_faces] 166 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 167 | return landmarks 168 | 169 | def seletec_3d68(self, vertices): 170 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 171 | self.full_lmk_faces_idx.repeat(vertices.shape[0], 1), 172 | self.full_lmk_bary_coords.repeat(vertices.shape[0], 1, 1)) 173 | return landmarks3d 174 | 175 | def forward(self, shape_params=None, expression_params=None, pose_params=None, eye_pose_params=None): 176 | """ 177 | Input: 178 | shape_params: N X number of shape parameters 179 | expression_params: N X number of expression parameters 180 | pose_params: N X number of pose parameters (6) 181 | return:d 182 | vertices: N X V X 3 183 | landmarks: N X number of landmarks X 3 184 | """ 185 | batch_size = shape_params.shape[0] 186 | if pose_params is None: 187 | pose_params = self.eye_pose.expand(batch_size, -1) 188 | if eye_pose_params is None: 189 | eye_pose_params = self.eye_pose.expand(batch_size, -1) 190 | betas = torch.cat([shape_params, expression_params], dim=1) 191 | full_pose = torch.cat([pose_params[:, :3], self.neck_pose.expand(batch_size, -1), pose_params[:, 3:], eye_pose_params], dim=1) 192 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 193 | 194 | vertices, _ = lbs(betas, full_pose, template_vertices, 195 | self.shapedirs, self.posedirs, 196 | self.J_regressor, self.parents, 197 | self.lbs_weights, dtype=self.dtype) 198 | 199 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 200 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 201 | 202 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords( 203 | full_pose, self.dynamic_lmk_faces_idx, 204 | self.dynamic_lmk_bary_coords, 205 | self.neck_kin_chain, dtype=self.dtype) 206 | lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1) 207 | lmk_bary_coords = torch.cat([dyn_lmk_bary_coords, lmk_bary_coords], 1) 208 | 209 | landmarks2d = vertices2landmarks(vertices, self.faces_tensor, 210 | lmk_faces_idx, 211 | lmk_bary_coords) 212 | bz = vertices.shape[0] 213 | landmarks3d = vertices2landmarks(vertices, self.faces_tensor, 214 | self.full_lmk_faces_idx.repeat(bz, 1), 215 | self.full_lmk_bary_coords.repeat(bz, 1, 1)) 216 | return vertices, landmarks2d, landmarks3d 217 | 218 | class FLAMETex(nn.Module): 219 | """ 220 | FLAME texture: 221 | 
https://github.com/TimoBolkart/TF_FLAME/blob/ade0ab152300ec5f0e8555d6765411555c5ed43d/sample_texture.py#L64 222 | FLAME texture converted from BFM: 223 | https://github.com/TimoBolkart/BFM_to_FLAME 224 | """ 225 | def __init__(self, config): 226 | super(FLAMETex, self).__init__() 227 | if config.tex_type == 'BFM': 228 | mu_key = 'MU' 229 | pc_key = 'PC' 230 | n_pc = 199 231 | tex_path = config.tex_path 232 | tex_space = np.load(tex_path) 233 | texture_mean = tex_space[mu_key].reshape(1, -1) 234 | texture_basis = tex_space[pc_key].reshape(-1, n_pc) 235 | 236 | elif config.tex_type == 'FLAME': 237 | mu_key = 'mean' 238 | pc_key = 'tex_dir' 239 | n_pc = 200 240 | tex_path = config.flame_tex_path 241 | tex_space = np.load(tex_path) 242 | texture_mean = tex_space[mu_key].reshape(1, -1)/255. 243 | texture_basis = tex_space[pc_key].reshape(-1, n_pc)/255. 244 | else: 245 | print('texture type ', config.tex_type, 'not exist!') 246 | raise NotImplementedError 247 | 248 | n_tex = config.n_tex 249 | num_components = texture_basis.shape[1] 250 | texture_mean = torch.from_numpy(texture_mean).float()[None,...] 251 | texture_basis = torch.from_numpy(texture_basis[:,:n_tex]).float()[None,...] 252 | self.register_buffer('texture_mean', texture_mean) 253 | self.register_buffer('texture_basis', texture_basis) 254 | 255 | 256 | def forward(self, texcode): 257 | ''' 258 | texcode: [batchsize, n_tex] 259 | texture: [bz, 3, 256, 256], range: 0-1 260 | ''' 261 | 262 | bs = texcode.shape[0] 263 | texcode = texcode[:1] 264 | 265 | # we use the same (first frame) texture for all frames 266 | 267 | texture = self.texture_mean + (self.texture_basis*texcode[:,None,:]).sum(-1) 268 | 269 | texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0,3,1,2) 270 | texture = F.interpolate(texture, [256, 256]) 271 | texture = texture[:,[2,1,0], :,:].repeat(bs,1,1,1) 272 | return texture 273 | -------------------------------------------------------------------------------- /external/spectre/src/models/encoders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | from . 
import resnet 6 | 7 | 8 | class PerceptualEncoder(nn.Module): 9 | def __init__(self, outsize, cfg): 10 | super(PerceptualEncoder, self).__init__() 11 | if cfg.backbone == "mobilenetv2": 12 | self.encoder = torch.hub.load('pytorch/vision:v0.8.1', 'mobilenet_v2', pretrained=True) 13 | feature_size = 1280 14 | elif cfg.backbone == "resnet50": 15 | self.encoder = resnet.load_ResNet50Model() #out: 2048 16 | feature_size = 2048 17 | 18 | ### regressor 19 | self.temporal = nn.Sequential( 20 | nn.Conv1d(in_channels=feature_size, out_channels=256, kernel_size=5, stride=1, padding=2), 21 | nn.BatchNorm1d(256), 22 | nn.ReLU() 23 | ) 24 | 25 | self.layers = nn.Sequential( 26 | nn.Linear(256, 53), 27 | ) 28 | 29 | self.backbone = cfg.backbone 30 | 31 | def forward(self, inputs): 32 | if self.backbone == 'resnet50': 33 | features = self.encoder(inputs).squeeze(-1).squeeze(-1) 34 | else: 35 | features = self.encoder.features(inputs) 36 | features = nn.functional.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1) 37 | 38 | features = features 39 | features = features.permute(1,0).unsqueeze(0) 40 | 41 | features = self.temporal(features) 42 | 43 | features = features.squeeze(0).permute(1,0) 44 | 45 | parameters = self.layers(features) 46 | 47 | parameters[...,50] = F.relu(parameters[...,50]) # jaw x is highly improbably negative and can introduce artifacts 48 | 49 | return parameters[...,:50], parameters[...,50:] 50 | 51 | 52 | class ResnetEncoder(nn.Module): 53 | def __init__(self, outsize): 54 | super(ResnetEncoder, self).__init__() 55 | 56 | feature_size = 2048 57 | 58 | self.encoder = resnet.load_ResNet50Model() #out: 2048 59 | ### regressor 60 | self.layers = nn.Sequential( 61 | nn.Linear(feature_size, 1024), 62 | nn.ReLU(), 63 | nn.Linear(1024, outsize) 64 | ) 65 | 66 | def forward(self, inputs): 67 | features = self.encoder(inputs) 68 | parameters = self.layers(features) 69 | 70 | return parameters 71 | 72 | -------------------------------------------------------------------------------- /external/spectre/src/models/expression_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import torch.nn as nn 17 | from torchvision import models 18 | from . 
import resnet 19 | 20 | 21 | 22 | class ExpressionLossNet(nn.Module): 23 | """ Code borrowed from EMOCA https://github.com/radekd91/emoca """ 24 | def __init__(self): 25 | super(ExpressionLossNet, self).__init__() 26 | 27 | self.backbone = resnet.load_ResNet50Model() #out: 2048 28 | 29 | self.linear = nn.Sequential( 30 | nn.Linear(2048, 10)) 31 | 32 | def forward2(self, inputs): 33 | features = self.backbone(inputs) 34 | out = self.linear(features) 35 | return features, out 36 | 37 | def forward(self, inputs): 38 | features = self.backbone(inputs) 39 | return features 40 | -------------------------------------------------------------------------------- /external/spectre/src/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Soubhik Sanyal 3 | Copyright (c) 2019, Soubhik Sanyal 4 | All rights reserved. 5 | Loads different resnet models 6 | """ 7 | ''' 8 | file: Resnet.py 9 | date: 2018_05_02 10 | author: zhangxiong(1025679612@qq.com) 11 | mark: copied from pytorch source code 12 | ''' 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch 17 | from torch.nn.parameter import Parameter 18 | import torch.optim as optim 19 | import numpy as np 20 | import math 21 | import torchvision 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layers, num_classes=1000): 25 | self.inplanes = 64 26 | super(ResNet, self).__init__() 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | self.layer1 = self._make_layer(block, 64, layers[0]) 33 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 36 | self.avgpool = nn.AvgPool2d(7, stride=1) 37 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | 47 | def _make_layer(self, block, planes, blocks, stride=1): 48 | downsample = None 49 | if stride != 1 or self.inplanes != planes * block.expansion: 50 | downsample = nn.Sequential( 51 | nn.Conv2d(self.inplanes, planes * block.expansion, 52 | kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(planes * block.expansion), 54 | ) 55 | 56 | layers = [] 57 | layers.append(block(self.inplanes, planes, stride, downsample)) 58 | self.inplanes = planes * block.expansion 59 | for i in range(1, blocks): 60 | layers.append(block(self.inplanes, planes)) 61 | 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = self.bn1(x) 67 | x = self.relu(x) 68 | x = self.maxpool(x) 69 | 70 | x = self.layer1(x) 71 | x = self.layer2(x) 72 | x = self.layer3(x) 73 | x1 = self.layer4(x) 74 | 75 | x2 = self.avgpool(x1) 76 | x2 = x2.view(x2.size(0), -1) 77 | # x = self.fc(x) 78 | ## x2: [bz, 2048] for shape 79 | ## x1: [bz, 2048, 7, 7] for texture 80 | return x2 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(planes) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 90 | padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 93 | self.bn3 = nn.BatchNorm2d(planes * 4) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.downsample = downsample 96 | self.stride = stride 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out += residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | def conv3x3(in_planes, out_planes, stride=1): 121 | """3x3 convolution with padding""" 122 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 123 | padding=1, bias=False) 124 | 125 | class BasicBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None): 129 | super(BasicBlock, self).__init__() 130 | self.conv1 = conv3x3(inplanes, planes, stride) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.conv2 = conv3x3(planes, planes) 134 | self.bn2 = nn.BatchNorm2d(planes) 135 | self.downsample = downsample 136 | self.stride = stride 137 | 138 | def forward(self, x): 139 | residual = x 140 | 141 | out = self.conv1(x) 142 | out = self.bn1(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv2(out) 146 | out = self.bn2(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | def copy_parameter_from_resnet(model, resnet_dict): 157 | cur_state_dict = model.state_dict() 158 | # import ipdb; ipdb.set_trace() 159 | for name, param in list(resnet_dict.items())[0:None]: 160 | if name not in cur_state_dict: 161 | # print(name, ' not available in reconstructed resnet') 162 | continue 163 | if isinstance(param, 
Parameter): 164 | param = param.data 165 | try: 166 | cur_state_dict[name].copy_(param) 167 | except: 168 | # print(name, ' is inconsistent!') 169 | continue 170 | # print('copy resnet state dict finished!') 171 | # import ipdb; ipdb.set_trace() 172 | 173 | 174 | def load_ResNet50Model(): 175 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 176 | copy_parameter_from_resnet(model, torchvision.models.resnet50(pretrained = True).state_dict()) 177 | return model 178 | 179 | def load_ResNet101Model(): 180 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 181 | copy_parameter_from_resnet(model, torchvision.models.resnet101(pretrained = True).state_dict()) 182 | return model 183 | 184 | def load_ResNet152Model(): 185 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 186 | copy_parameter_from_resnet(model, torchvision.models.resnet152(pretrained = True).state_dict()) 187 | return model 188 | 189 | # model.load_state_dict(checkpoint['model_state_dict']) 190 | 191 | 192 | ######## Unet 193 | 194 | class DoubleConv(nn.Module): 195 | """(convolution => [BN] => ReLU) * 2""" 196 | 197 | def __init__(self, in_channels, out_channels): 198 | super().__init__() 199 | self.double_conv = nn.Sequential( 200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 201 | nn.BatchNorm2d(out_channels), 202 | nn.ReLU(inplace=True), 203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 204 | nn.BatchNorm2d(out_channels), 205 | nn.ReLU(inplace=True) 206 | ) 207 | 208 | def forward(self, x): 209 | return self.double_conv(x) 210 | 211 | 212 | class Down(nn.Module): 213 | """Downscaling with maxpool then double conv""" 214 | 215 | def __init__(self, in_channels, out_channels): 216 | super().__init__() 217 | self.maxpool_conv = nn.Sequential( 218 | nn.MaxPool2d(2), 219 | DoubleConv(in_channels, out_channels) 220 | ) 221 | 222 | def forward(self, x): 223 | return self.maxpool_conv(x) 224 | 225 | 226 | class Up(nn.Module): 227 | """Upscaling then double conv""" 228 | 229 | def __init__(self, in_channels, out_channels, bilinear=True): 230 | super().__init__() 231 | 232 | # if bilinear, use the normal convolutions to reduce the number of channels 233 | if bilinear: 234 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 235 | else: 236 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2) 237 | 238 | self.conv = DoubleConv(in_channels, out_channels) 239 | 240 | def forward(self, x1, x2): 241 | x1 = self.up(x1) 242 | # input is CHW 243 | diffY = x2.size()[2] - x1.size()[2] 244 | diffX = x2.size()[3] - x1.size()[3] 245 | 246 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 247 | diffY // 2, diffY - diffY // 2]) 248 | # if you have padding issues, see 249 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 250 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 251 | x = torch.cat([x2, x1], dim=1) 252 | return self.conv(x) 253 | 254 | 255 | class OutConv(nn.Module): 256 | def __init__(self, in_channels, out_channels): 257 | super(OutConv, self).__init__() 258 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 259 | 260 | def forward(self, x): 261 | return self.conv(x) -------------------------------------------------------------------------------- /external/spectre/src/spectre.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Max-Planck-Gesellschaft zur 
Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # Using this computer program means that you agree to the terms 6 | # in the LICENSE file included with this software distribution. 7 | # Any use not explicitly granted by the LICENSE is prohibited. 8 | # 9 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 10 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 11 | # for Intelligent Systems. All rights reserved. 12 | # 13 | # For comments or questions, please email us at deca@tue.mpg.de 14 | # For commercial licensing contact, please contact ps-license@tuebingen.mpg.de 15 | 16 | import os 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from .models.encoders import PerceptualEncoder 22 | from .utils.renderer import SRenderY, set_rasterizer 23 | from .models.encoders import ResnetEncoder 24 | from .models.FLAME import FLAME, FLAMETex 25 | from .utils import util 26 | from .utils.tensor_cropper import transform_points 27 | from skimage.io import imread 28 | torch.backends.cudnn.benchmark = True 29 | import numpy as np 30 | 31 | class SPECTRE(nn.Module): 32 | def __init__(self, config=None, device='cuda'): 33 | super(SPECTRE, self).__init__() 34 | self.cfg = config 35 | self.device = device 36 | self.image_size = self.cfg.dataset.image_size 37 | self.uv_size = self.cfg.model.uv_size 38 | self._create_model(self.cfg.model) 39 | self._setup_renderer(self.cfg.model) 40 | 41 | 42 | def _setup_renderer(self, model_cfg): 43 | set_rasterizer(self.cfg.rasterizer_type) 44 | self.render = SRenderY(self.image_size, obj_filename=model_cfg.topology_path, uv_size=model_cfg.uv_size, rasterizer_type=self.cfg.rasterizer_type).to(self.device) 45 | # face mask for rendering details 46 | mask = imread(model_cfg.face_eye_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() 47 | self.uv_face_eye_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 48 | mask = imread(model_cfg.face_mask_path).astype(np.float32)/255.; mask = torch.from_numpy(mask[:,:,0])[None,None,:,:].contiguous() 49 | self.uv_face_mask = F.interpolate(mask, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 50 | # displacement correction 51 | fixed_dis = np.load(model_cfg.fixed_displacement_path) 52 | self.fixed_uv_dis = torch.tensor(fixed_dis).float().to(self.device) 53 | # mean texture 54 | mean_texture = imread(model_cfg.mean_tex_path).astype(np.float32)/255.; mean_texture = torch.from_numpy(mean_texture.transpose(2,0,1))[None,:,:,:].contiguous() 55 | self.mean_texture = F.interpolate(mean_texture, [model_cfg.uv_size, model_cfg.uv_size]).to(self.device) 56 | # dense mesh template, for save detail mesh 57 | self.dense_template = np.load(model_cfg.dense_template_path, allow_pickle=True, encoding='latin1').item() 58 | 59 | def _create_model(self, model_cfg): 60 | # set up parameters 61 | self.n_param = model_cfg.n_shape + model_cfg.n_tex + model_cfg.n_exp + model_cfg.n_pose + model_cfg.n_cam + model_cfg.n_light 62 | self.n_cond = model_cfg.n_exp + 3 # exp + jaw pose 63 | self.num_list = [model_cfg.n_shape, model_cfg.n_tex, model_cfg.n_exp, model_cfg.n_pose, model_cfg.n_cam, 64 | model_cfg.n_light] 65 | self.param_dict = {i: model_cfg.get('n_' + i) for i in model_cfg.param_list} 66 | 67 | # encoders 68 | self.E_flame = ResnetEncoder(outsize=self.n_param).to(self.device) 69 | 70 | self.E_expression = 
PerceptualEncoder(model_cfg.n_exp, model_cfg).to(self.device) 71 | 72 | # decoders 73 | self.flame = FLAME(model_cfg).to(self.device) 74 | if model_cfg.use_tex: 75 | self.flametex = FLAMETex(model_cfg).to(self.device) 76 | 77 | # resume model 78 | model_path = self.cfg.pretrained_modelpath 79 | if os.path.exists(model_path): 80 | # print(f'trained model found. load {model_path}') 81 | checkpoint = torch.load(model_path) 82 | 83 | if 'state_dict' in checkpoint.keys(): 84 | self.checkpoint = checkpoint['state_dict'] 85 | else: 86 | self.checkpoint = checkpoint 87 | 88 | processed_checkpoint = {} 89 | processed_checkpoint["E_flame"] = {} 90 | processed_checkpoint["E_expression"] = {} 91 | if 'deca' in list(self.checkpoint.keys())[0]: 92 | for key in self.checkpoint.keys(): 93 | # print(key) 94 | k = key.replace("deca.","") 95 | if "E_flame" in key: 96 | processed_checkpoint["E_flame"][k.replace("E_flame.","")] = self.checkpoint[key]#.replace("E_flame","") 97 | elif "E_expression" in key: 98 | processed_checkpoint["E_expression"][k.replace("E_expression.","")] = self.checkpoint[key]#.replace("E_flame","") 99 | else: 100 | pass 101 | 102 | else: 103 | processed_checkpoint = self.checkpoint 104 | 105 | 106 | self.E_flame.load_state_dict(processed_checkpoint['E_flame'], strict=True) 107 | try: 108 | m,u = self.E_expression.load_state_dict(processed_checkpoint['E_expression'], strict=True) 109 | # print('Missing keys', m) 110 | # print('Unexpected keys', u) 111 | # pass 112 | except Exception as e: 113 | print(f'Missing keys {e} in expression encoder weights. If starting training from scratch this is normal.') 114 | else: 115 | raise(f'please check model path: {model_path}') 116 | 117 | # eval mode 118 | self.E_flame.eval() 119 | 120 | self.E_expression.eval() 121 | 122 | self.E_flame.requires_grad_(False) 123 | 124 | 125 | def decompose_code(self, code, num_dict): 126 | ''' Convert a flattened parameter vector to a dictionary of parameters 127 | code_dict.keys() = ['shape', 'tex', 'exp', 'pose', 'cam', 'light'] 128 | ''' 129 | code_dict = {} 130 | start = 0 131 | for key in num_dict: 132 | end = start + int(num_dict[key]) 133 | code_dict[key] = code[:, start:end] 134 | start = end 135 | if key == 'light': 136 | code_dict[key] = code_dict[key].reshape(code_dict[key].shape[0], 9, 3) 137 | return code_dict 138 | 139 | def encode(self, images): 140 | with torch.no_grad(): 141 | parameters = self.E_flame(images) 142 | 143 | codedict = self.decompose_code(parameters, self.param_dict) 144 | deca_exp = codedict['exp'].clone() 145 | deca_jaw = codedict['pose'][:,3:].clone() 146 | 147 | codedict['images'] = images 148 | 149 | codedict['exp'], jaw = self.E_expression(images) 150 | codedict['pose'][:, 3:] = jaw 151 | 152 | return codedict, deca_exp, deca_jaw 153 | 154 | 155 | def decode(self, codedict, rendering=True, vis_lmk=False, return_vis=False, 156 | render_orig=False, original_image=None, tform=None): 157 | # images = codedict['images'] 158 | # batch_size = images.shape[0] 159 | batch_size = 1 160 | 161 | ## decode 162 | verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape'], expression_params=codedict['exp'], 163 | pose_params=codedict['pose']) 164 | if self.cfg.model.use_tex: 165 | albedo = self.flametex(codedict['tex']).detach() 166 | else: 167 | albedo = torch.zeros([batch_size, 3, self.uv_size, self.uv_size], device=self.device) 168 | landmarks3d_world = landmarks3d.clone() 169 | 170 | ## projection 171 | landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:, :, 
:2]; 172 | landmarks2d[:, :, 1:] = -landmarks2d[:, :, 173 | 1:] 174 | landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam']); 175 | landmarks3d[:, :, 1:] = -landmarks3d[:, :, 176 | 1:] 177 | trans_verts = util.batch_orth_proj(verts, codedict['cam']); 178 | trans_verts[:, :, 1:] = -trans_verts[:, :, 1:] 179 | opdict = { 180 | 'verts': verts, 181 | 'trans_verts': trans_verts, 182 | 'landmarks2d': landmarks2d, 183 | 'landmarks3d': landmarks3d, 184 | 'landmarks3d_world': landmarks3d_world, 185 | } 186 | 187 | if rendering and render_orig and original_image is not None and tform is not None: 188 | points_scale = [self.image_size, self.image_size] 189 | _, _, h, w = original_image.shape 190 | trans_verts = transform_points(trans_verts, tform, points_scale, [h, w]) 191 | landmarks2d = transform_points(landmarks2d, tform, points_scale, [h, w]) 192 | landmarks3d = transform_points(landmarks3d, tform, points_scale, [h, w]) 193 | background = images 194 | else: 195 | h, w = self.image_size, self.image_size 196 | background = None 197 | 198 | 199 | if rendering: 200 | if self.cfg.model.use_tex: 201 | ops = self.render(verts, trans_verts, albedo, codedict['light']) 202 | ## output 203 | opdict['predicted_inner_mouth'] = ops['predicted_inner_mouth'] 204 | opdict['grid'] = ops['grid'] 205 | opdict['rendered_images'] = ops['images'] 206 | opdict['alpha_images'] = ops['alpha_images'] 207 | opdict['normal_images'] = ops['normal_images'] 208 | opdict['images'] = images 209 | 210 | else: 211 | shape_images, _, grid, alpha_images, pos_mask = self.render.render_shape(verts, trans_verts, h=h, w=w, 212 | images=background, 213 | return_grid=True, 214 | return_pos=True) 215 | 216 | opdict['rendered_images'] = shape_images 217 | 218 | if self.cfg.model.use_tex: 219 | opdict['albedo'] = albedo 220 | 221 | if vis_lmk: 222 | landmarks3d_vis = self.visofp(ops['transformed_normals']) # /self.image_size 223 | landmarks3d = torch.cat([landmarks3d, landmarks3d_vis], dim=2) 224 | opdict['landmarks3d'] = landmarks3d 225 | 226 | if return_vis: 227 | ## render shape 228 | shape_images, _, grid, alpha_images, pos_mask = self.render.render_shape(verts, trans_verts, h=h, w=w, 229 | images=background, return_grid=True, return_pos=True) 230 | 231 | # opdict['uv_texture_gt'] = uv_texture_gt 232 | visdict = { 233 | # 'inputs': images, 234 | 'landmarks2d': util.tensor_vis_landmarks(images, landmarks2d), 235 | 'landmarks3d': util.tensor_vis_landmarks(images, landmarks3d), 236 | 'shape_images': shape_images, 237 | # 'rendered_images': ops['images'] 238 | } 239 | 240 | return opdict, visdict 241 | 242 | else: 243 | return opdict 244 | 245 | def train(self): 246 | self.E_expression.train() 247 | 248 | self.E_flame.eval() 249 | 250 | 251 | def eval(self): 252 | self.E_expression.eval() 253 | self.E_flame.eval() 254 | 255 | 256 | def model_dict(self): 257 | return { 258 | 'E_flame': self.E_flame.state_dict(), 259 | 'E_expression': self.E_expression.state_dict(), 260 | } 261 | -------------------------------------------------------------------------------- /external/spectre/src/utils/lossfunc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def l2_distance(verts1, verts2): 8 | return torch.sqrt(((verts1 - verts2)**2).sum(2)).mean(1).mean() 9 | 10 | ### ------------------------------------- Losses/Regularizations for vertices 11 | def batch_kp_2d_l1_loss(real_2d_kp, predicted_2d_kp, 
weights=None): 12 | """ 13 | Computes the l1 loss between the ground truth keypoints and the predicted keypoints 14 | Inputs: 15 | kp_gt : N x K x 3 16 | kp_pred: N x K x 2 17 | """ 18 | if weights is not None: 19 | real_2d_kp[:,:,2] = weights[None,:]*real_2d_kp[:,:,2] 20 | kp_gt = real_2d_kp.view(-1, 3) 21 | kp_pred = predicted_2d_kp.contiguous().view(-1, 2) 22 | vis = kp_gt[:, 2] 23 | k = torch.sum(vis) * 2.0 + 1e-8 24 | 25 | dif_abs = torch.abs(kp_gt[:, :2] - kp_pred).sum(1) 26 | 27 | return torch.matmul(dif_abs, vis) * 1.0 / k 28 | 29 | def landmark_loss(predicted_landmarks, landmarks_gt, weight=1.): 30 | if torch.is_tensor(landmarks_gt) is not True: 31 | real_2d = torch.cat(landmarks_gt).cuda() 32 | else: 33 | real_2d = torch.cat([landmarks_gt, torch.ones((landmarks_gt.shape[0], 68, 1)).cuda()], dim=-1) 34 | 35 | loss_lmk_2d = batch_kp_2d_l1_loss(real_2d, predicted_landmarks) 36 | return loss_lmk_2d * weight 37 | 38 | 39 | def weighted_landmark_loss(predicted_landmarks, landmarks_gt, weight=1.): 40 | #smaller inner landmark weights 41 | # (predicted_theta, predicted_verts, predicted_landmarks) = ringnet_outputs[-1] 42 | # import ipdb; ipdb.set_trace() 43 | real_2d = landmarks_gt 44 | weights = torch.ones((68,)).cuda() 45 | weights[5:7] = 2 46 | weights[10:12] = 2 47 | # nose points 48 | weights[27:36] = 1.5 49 | weights[30] = 3 50 | weights[31] = 3 51 | weights[35] = 3 52 | 53 | # set mouth to zero 54 | weights[60:68] = 0 55 | weights[48:60] = 0 56 | weights[48] = 0 57 | weights[54] = 0 58 | 59 | 60 | # weights[36:48] = 0 # these are eyes 61 | 62 | loss_lmk_2d = batch_kp_2d_l1_loss(real_2d, predicted_landmarks, weights) 63 | return loss_lmk_2d * weight 64 | 65 | 66 | def rel_dis(landmarks): 67 | 68 | lip_right = landmarks[:, [57, 51, 48, 60, 61, 62, 63], :] 69 | lip_left = landmarks[:, [8, 33, 54, 64, 67, 66, 65], :] 70 | 71 | # lip_right = landmarks[:, [61, 62, 63], :] 72 | # lip_left = landmarks[:, [67, 66, 65], :] 73 | 74 | dis = torch.sqrt(((lip_right - lip_left) ** 2).sum(2)) # [bz, 4] 75 | 76 | return dis 77 | 78 | def relative_landmark_loss(predicted_landmarks, landmarks_gt, weight=1.): 79 | if torch.is_tensor(landmarks_gt) is not True: 80 | real_2d = torch.cat(landmarks_gt)#.cuda() 81 | else: 82 | real_2d = torch.cat([landmarks_gt, torch.ones((landmarks_gt.shape[0], 68, 1)).to(device=predicted_landmarks.device) #.cuda() 83 | ], dim=-1) 84 | pred_lipd = rel_dis(predicted_landmarks[:, :, :2]) 85 | gt_lipd = rel_dis(real_2d[:, :, :2]) 86 | 87 | loss = (pred_lipd - gt_lipd).abs().mean() 88 | # loss = F.mse_loss(pred_lipd, gt_lipd) 89 | 90 | return loss.mean() 91 | 92 | -------------------------------------------------------------------------------- /external/spectre/src/utils/rotation_converter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ''' Rotation Converter 4 | Repre: euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4) 5 | ref: https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html# 6 | "pi", 7 | "rad2deg", 8 | "deg2rad", 9 | # "angle_axis_to_rotation_matrix", batch_rodrigues 10 | "rotation_matrix_to_angle_axis", 11 | "rotation_matrix_to_quaternion", 12 | "quaternion_to_angle_axis", 13 | # "angle_axis_to_quaternion", 14 | 15 | euler2quat_conversion_sanity_batch 16 | 17 | ref: smplx/lbs 18 | batch_rodrigues: axis angle -> matrix 19 | # 20 | ''' 21 | pi = torch.Tensor([3.14159265358979323846]) 22 | 23 | def rad2deg(tensor): 24 | """Function that converts angles from 
radians to degrees. 25 | 26 | See :class:`~torchgeometry.RadToDeg` for details. 27 | 28 | Args: 29 | tensor (Tensor): Tensor of arbitrary shape. 30 | 31 | Returns: 32 | Tensor: Tensor with same shape as input. 33 | 34 | Example: 35 | >>> input = tgm.pi * torch.rand(1, 3, 3) 36 | >>> output = tgm.rad2deg(input) 37 | """ 38 | if not torch.is_tensor(tensor): 39 | raise TypeError("Input type is not a torch.Tensor. Got {}" 40 | .format(type(tensor))) 41 | 42 | return 180. * tensor / pi.to(tensor.device).type(tensor.dtype) 43 | 44 | def deg2rad(tensor): 45 | """Function that converts angles from degrees to radians. 46 | 47 | See :class:`~torchgeometry.DegToRad` for details. 48 | 49 | Args: 50 | tensor (Tensor): Tensor of arbitrary shape. 51 | 52 | Returns: 53 | Tensor: Tensor with same shape as input. 54 | 55 | Examples:: 56 | 57 | >>> input = 360. * torch.rand(1, 3, 3) 58 | >>> output = tgm.deg2rad(input) 59 | """ 60 | if not torch.is_tensor(tensor): 61 | raise TypeError("Input type is not a torch.Tensor. Got {}" 62 | .format(type(tensor))) 63 | 64 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. 65 | 66 | ######### to quaternion 67 | def euler_to_quaternion(r): 68 | x = r[..., 0] 69 | y = r[..., 1] 70 | z = r[..., 2] 71 | 72 | z = z/2.0 73 | y = y/2.0 74 | x = x/2.0 75 | cz = torch.cos(z) 76 | sz = torch.sin(z) 77 | cy = torch.cos(y) 78 | sy = torch.sin(y) 79 | cx = torch.cos(x) 80 | sx = torch.sin(x) 81 | quaternion = torch.zeros_like(r.repeat(1,2))[..., :4].to(r.device) 82 | quaternion[..., 0] += cx*cy*cz - sx*sy*sz 83 | quaternion[..., 1] += cx*sy*sz + cy*cz*sx 84 | quaternion[..., 2] += cx*cz*sy - sx*cy*sz 85 | quaternion[..., 3] += cx*cy*sz + sx*cz*sy 86 | return quaternion 87 | 88 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): 89 | """Convert 3x4 rotation matrix to 4d quaternion vector 90 | 91 | This algorithm is based on algorithm described in 92 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 93 | 94 | Args: 95 | rotation_matrix (Tensor): the rotation matrix to convert. 96 | 97 | Return: 98 | Tensor: the rotation in quaternion 99 | 100 | Shape: 101 | - Input: :math:`(N, 3, 4)` 102 | - Output: :math:`(N, 4)` 103 | 104 | Example: 105 | >>> input = torch.rand(4, 3, 4) # Nx3x4 106 | >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 107 | """ 108 | if not torch.is_tensor(rotation_matrix): 109 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 110 | type(rotation_matrix))) 111 | 112 | if len(rotation_matrix.shape) > 3: 113 | raise ValueError( 114 | "Input size must be a three dimensional tensor. Got {}".format( 115 | rotation_matrix.shape)) 116 | # if not rotation_matrix.shape[-2:] == (3, 4): 117 | # raise ValueError( 118 | # "Input size must be a N x 3 x 4 tensor. 
Got {}".format( 119 | # rotation_matrix.shape)) 120 | 121 | rmat_t = torch.transpose(rotation_matrix, 1, 2) 122 | 123 | mask_d2 = rmat_t[:, 2, 2] < eps 124 | 125 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] 126 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] 127 | 128 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 129 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 130 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 131 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) 132 | t0_rep = t0.repeat(4, 1).t() 133 | 134 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 135 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 136 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 137 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) 138 | t1_rep = t1.repeat(4, 1).t() 139 | 140 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 141 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], 142 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2], 143 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) 144 | t2_rep = t2.repeat(4, 1).t() 145 | 146 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 147 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 148 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 149 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) 150 | t3_rep = t3.repeat(4, 1).t() 151 | 152 | mask_c0 = mask_d2 * mask_d0_d1.float() 153 | mask_c1 = mask_d2 * (1 - mask_d0_d1.float()) 154 | mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1 155 | mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float()) 156 | mask_c0 = mask_c0.view(-1, 1).type_as(q0) 157 | mask_c1 = mask_c1.view(-1, 1).type_as(q1) 158 | mask_c2 = mask_c2.view(-1, 1).type_as(q2) 159 | mask_c3 = mask_c3.view(-1, 1).type_as(q3) 160 | 161 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 162 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa 163 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa 164 | q *= 0.5 165 | return q 166 | 167 | # def angle_axis_to_quaternion(theta): 168 | # batch_size = theta.shape[0] 169 | # l1norm = torch.norm(theta + 1e-8, p=2, dim=1) 170 | # angle = torch.unsqueeze(l1norm, -1) 171 | # normalized = torch.div(theta, angle) 172 | # angle = angle * 0.5 173 | # v_cos = torch.cos(angle) 174 | # v_sin = torch.sin(angle) 175 | # quat = torch.cat([v_cos, v_sin * normalized], dim=1) 176 | # return quat 177 | 178 | def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor: 179 | """Convert an angle axis to a quaternion. 180 | 181 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 182 | 183 | Args: 184 | angle_axis (torch.Tensor): tensor with angle axis. 185 | 186 | Return: 187 | torch.Tensor: tensor with quaternion. 188 | 189 | Shape: 190 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions 191 | - Output: :math:`(*, 4)` 192 | 193 | Example: 194 | >>> angle_axis = torch.rand(2, 4) # Nx4 195 | >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 196 | """ 197 | if not torch.is_tensor(angle_axis): 198 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 199 | type(angle_axis))) 200 | 201 | if not angle_axis.shape[-1] == 3: 202 | raise ValueError("Input must be a tensor of shape Nx3 or 3. 
Got {}" 203 | .format(angle_axis.shape)) 204 | # unpack input and compute conversion 205 | a0: torch.Tensor = angle_axis[..., 0:1] 206 | a1: torch.Tensor = angle_axis[..., 1:2] 207 | a2: torch.Tensor = angle_axis[..., 2:3] 208 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2 209 | 210 | theta: torch.Tensor = torch.sqrt(theta_squared) 211 | half_theta: torch.Tensor = theta * 0.5 212 | 213 | mask: torch.Tensor = theta_squared > 0.0 214 | ones: torch.Tensor = torch.ones_like(half_theta) 215 | 216 | k_neg: torch.Tensor = 0.5 * ones 217 | k_pos: torch.Tensor = torch.sin(half_theta) / theta 218 | k: torch.Tensor = torch.where(mask, k_pos, k_neg) 219 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones) 220 | 221 | quaternion: torch.Tensor = torch.zeros_like(angle_axis) 222 | quaternion[..., 0:1] += a0 * k 223 | quaternion[..., 1:2] += a1 * k 224 | quaternion[..., 2:3] += a2 * k 225 | return torch.cat([w, quaternion], dim=-1) 226 | 227 | #### quaternion to 228 | def quaternion_to_rotation_matrix(quat): 229 | """Convert quaternion coefficients to rotation matrix. 230 | Args: 231 | quat: size = [B, 4] 4 <===>(w, x, y, z) 232 | Returns: 233 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 234 | """ 235 | norm_quat = quat 236 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) 237 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] 238 | 239 | B = quat.size(0) 240 | 241 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 242 | wx, wy, wz = w * x, w * y, w * z 243 | xy, xz, yz = x * y, x * z, y * z 244 | 245 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 246 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 247 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 248 | return rotMat 249 | 250 | def quaternion_to_angle_axis(quaternion: torch.Tensor): 251 | """Convert quaternion vector to angle axis of rotation. TODO: CORRECT 252 | 253 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 254 | 255 | Args: 256 | quaternion (torch.Tensor): tensor with quaternions. 257 | 258 | Return: 259 | torch.Tensor: tensor with angle axis of rotation. 260 | 261 | Shape: 262 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 263 | - Output: :math:`(*, 3)` 264 | 265 | Example: 266 | >>> quaternion = torch.rand(2, 4) # Nx4 267 | >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 268 | """ 269 | if not torch.is_tensor(quaternion): 270 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 271 | type(quaternion))) 272 | 273 | if not quaternion.shape[-1] == 4: 274 | raise ValueError("Input must be a tensor of shape Nx4 or 4. 
Got {}" 275 | .format(quaternion.shape)) 276 | # unpack input and compute conversion 277 | q1: torch.Tensor = quaternion[..., 1] 278 | q2: torch.Tensor = quaternion[..., 2] 279 | q3: torch.Tensor = quaternion[..., 3] 280 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 281 | 282 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) 283 | cos_theta: torch.Tensor = quaternion[..., 0] 284 | two_theta: torch.Tensor = 2.0 * torch.where( 285 | cos_theta < 0.0, 286 | torch.atan2(-sin_theta, -cos_theta), 287 | torch.atan2(sin_theta, cos_theta)) 288 | 289 | k_pos: torch.Tensor = two_theta / sin_theta 290 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device) 291 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 292 | 293 | angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3] 294 | angle_axis[..., 0] += q1 * k 295 | angle_axis[..., 1] += q2 * k 296 | angle_axis[..., 2] += q3 * k 297 | return angle_axis 298 | 299 | #### batch converter 300 | def batch_euler2axis(r): 301 | return quaternion_to_angle_axis(euler_to_quaternion(r)) 302 | 303 | def batch_euler2matrix(r): 304 | return quaternion_to_rotation_matrix(euler_to_quaternion(r)) 305 | 306 | def batch_matrix2euler(rot_mats): 307 | # Calculates rotation matrix to euler angles 308 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 309 | ### only y? 310 | # TODO: 311 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 312 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 313 | return torch.atan2(-rot_mats[:, 2, 0], sy) 314 | 315 | def batch_matrix2axis(rot_mats): 316 | return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rot_mats)) 317 | 318 | def batch_axis2matrix(theta): 319 | # angle axis to rotation matrix 320 | # theta N x 3 321 | # return quat2mat(quat) 322 | # batch_rodrigues 323 | return quaternion_to_rotation_matrix(angle_axis_to_quaternion(theta)) 324 | 325 | def batch_axis2euler(theta): 326 | return batch_matrix2euler(batch_axis2matrix(theta)) 327 | 328 | def batch_axis2euler(r): 329 | return rot_mat_to_euler(batch_rodrigues(r)) 330 | 331 | 332 | def batch_orth_proj(X, camera): 333 | ''' 334 | X is N x num_pquaternion_to_angle_axisoints x 3 335 | ''' 336 | camera = camera.clone().view(-1, 1, 3) 337 | X_trans = X[:, :, :2] + camera[:, :, 1:] 338 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2) 339 | Xn = (camera[:, :, 0:1] * X_trans) 340 | return Xn 341 | 342 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 343 | ''' same as batch_matrix2axis 344 | Calculates the rotation matrices for a batch of rotation vectors 345 | Parameters 346 | ---------- 347 | rot_vecs: torch.tensor Nx3 348 | array of N axis-angle vectors 349 | Returns 350 | ------- 351 | R: torch.tensor Nx3x3 352 | The rotation matrices for the given axis-angle parameters 353 | ''' 354 | 355 | batch_size = rot_vecs.shape[0] 356 | device = rot_vecs.device 357 | 358 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 359 | rot_dir = rot_vecs / angle 360 | 361 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 362 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 363 | 364 | # Bx1 arrays 365 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 366 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 367 | 368 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 369 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 370 | .view((batch_size, 3, 3)) 371 | 372 | ident = torch.eye(3, dtype=dtype, 
device=device).unsqueeze(dim=0) 373 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 374 | return rot_mat 375 | -------------------------------------------------------------------------------- /external/spectre/src/utils/tensor_cropper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | crop 3 | for torch tensor 4 | Given image, bbox(center, bboxsize) 5 | return: cropped image, tform(used for transform the keypoint accordingly) 6 | only support crop to squared images 7 | ''' 8 | import torch 9 | from kornia.geometry.transform.imgwarp import ( 10 | warp_perspective, get_perspective_transform, warp_affine 11 | ) 12 | 13 | def points2bbox(points, points_scale=None): 14 | if points_scale: 15 | assert points_scale[0]==points_scale[1] 16 | points = points.clone() 17 | points[:,:,:2] = (points[:,:,:2]*0.5 + 0.5)*points_scale[0] 18 | min_coords, _ = torch.min(points, dim=1) 19 | xmin, ymin = min_coords[:, 0], min_coords[:, 1] 20 | max_coords, _ = torch.max(points, dim=1) 21 | xmax, ymax = max_coords[:, 0], max_coords[:, 1] 22 | center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5 23 | 24 | width = (xmax - xmin) 25 | height = (ymax - ymin) 26 | # Convert the bounding box to a square box 27 | size = torch.max(width, height).unsqueeze(-1) 28 | return center, size 29 | 30 | def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.): 31 | batch_size = center.shape[0] 32 | trans_scale = (torch.rand([batch_size, 2], device=center.device)*2. -1.) * trans_scale 33 | center = center + trans_scale*bbox_size # 0.5 34 | scale = torch.rand([batch_size,1], device=center.device) * (scale[1] - scale[0]) + scale[0] 35 | size = bbox_size*scale 36 | return center, size 37 | 38 | def crop_tensor(image, center, bbox_size, crop_size, interpolation = 'bilinear', align_corners=False): 39 | ''' for batch image 40 | Args: 41 | image (torch.Tensor): the reference tensor of shape BXHxWXC. 42 | center: [bz, 2] 43 | bboxsize: [bz, 1] 44 | crop_size; 45 | interpolation (str): Interpolation flag. Default: 'bilinear'. 46 | align_corners (bool): mode for grid_generation. Default: False. 
See 47 | https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details 48 | Returns: 49 | cropped_image 50 | tform 51 | ''' 52 | dtype = image.dtype 53 | device = image.device 54 | batch_size = image.shape[0] 55 | # points: top-left, top-right, bottom-right, bottom-left 56 | src_pts = torch.zeros([4,2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous() 57 | 58 | src_pts[:, 0, :] = center - bbox_size*0.5 # / (self.crop_size - 1) 59 | src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5 60 | src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5 61 | src_pts[:, 2, :] = center + bbox_size * 0.5 62 | src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5 63 | src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5 64 | 65 | DST_PTS = torch.tensor([[ 66 | [0, 0], 67 | [crop_size - 1, 0], 68 | [crop_size - 1, crop_size - 1], 69 | [0, crop_size - 1], 70 | ]], dtype=dtype, device=device).expand(batch_size, -1, -1) 71 | # estimate transformation between points 72 | dst_trans_src = get_perspective_transform(src_pts, DST_PTS) 73 | # simulate broadcasting 74 | # dst_trans_src = dst_trans_src.expand(batch_size, -1, -1) 75 | 76 | # warp images 77 | cropped_image = warp_affine( 78 | image, dst_trans_src[:, :2, :], (crop_size, crop_size), 79 | flags=interpolation, align_corners=align_corners) 80 | 81 | tform = torch.transpose(dst_trans_src, 2, 1) 82 | # tform = torch.inverse(dst_trans_src) 83 | return cropped_image, tform 84 | 85 | class Cropper(object): 86 | def __init__(self, crop_size, scale=[1,1], trans_scale = 0.): 87 | self.crop_size = crop_size 88 | self.scale = scale 89 | self.trans_scale = trans_scale 90 | 91 | def crop(self, image, points, points_scale=None): 92 | # points to bbox 93 | center, bbox_size = points2bbox(points.clone(), points_scale) 94 | # argument bbox. TODO: add rotation? 
95 | center, bbox_size = augment_bbox(center, bbox_size, scale=self.scale, trans_scale=self.trans_scale) 96 | # crop 97 | cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size) 98 | return cropped_image, tform 99 | 100 | def transform_points(self, points, tform, points_scale=None, normalize = True): 101 | points_2d = points[:,:,:2] 102 | 103 | #'input points must use original range' 104 | if points_scale: 105 | assert points_scale[0]==points_scale[1] 106 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 107 | 108 | batch_size, n_points, _ = points.shape 109 | trans_points_2d = torch.bmm( 110 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 111 | tform 112 | ) 113 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 114 | if normalize: 115 | trans_points[:,:,:2] = trans_points[:,:,:2]/self.crop_size*2 - 1 116 | return trans_points 117 | 118 | def transform_points(points, tform, points_scale=None, out_scale=None): 119 | points_2d = points[:,:,:2] 120 | 121 | #'input points must use original range' 122 | if points_scale: 123 | assert points_scale[0]==points_scale[1] 124 | points_2d = (points_2d*0.5 + 0.5)*points_scale[0] 125 | # import ipdb; ipdb.set_trace() 126 | 127 | batch_size, n_points, _ = points.shape 128 | trans_points_2d = torch.bmm( 129 | torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1), 130 | tform 131 | ) 132 | if out_scale: # h,w of output image size 133 | trans_points_2d[:,:,0] = trans_points_2d[:,:,0]/out_scale[1]*2 - 1 134 | trans_points_2d[:,:,1] = trans_points_2d[:,:,1]/out_scale[0]*2 - 1 135 | trans_points = torch.cat([trans_points_2d[:,:,:2], points[:,:,2:]], dim=-1) 136 | return trans_points -------------------------------------------------------------------------------- /external/spectre/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/external/spectre/utils/__init__.py -------------------------------------------------------------------------------- /external/spectre/utils/extract_frames_LRS3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import numpy as np 5 | import torch 6 | from argparse import ArgumentParser 7 | 8 | import sys 9 | 10 | 11 | 12 | def extract(video, tmpl='%06d.jpg'): 13 | os.makedirs(video.replace(".mp4", ""),exist_ok=True) 14 | cmd = 'ffmpeg -i \"{}\" -threads 1 -q:v 0 \"{}/%06d.jpg\"'.format(video, 15 | video.replace(".mp4", "")) 16 | os.system(cmd) 17 | 18 | # os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace(".mp4",".wav"))) 19 | 20 | 21 | # -*- coding: utf-8 -*- 22 | 23 | import os, sys 24 | import cv2 25 | import numpy as np 26 | from time import time 27 | from scipy.io import savemat 28 | import argparse 29 | from tqdm import tqdm 30 | import torch 31 | 32 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 33 | from decalib.deca import DECA 34 | from decalib.datasets import datasets 35 | from decalib.utils import util 36 | from decalib.utils.config import cfg as deca_cfg 37 | import pickle 38 | 39 | 40 | 41 | def video2sequence(video_path, videofolder): 42 | os.makedirs(videofolder, exist_ok=True) 43 | video_name = os.path.splitext(os.path.split(video_path)[-1])[0] 44 | vidcap = 
cv2.VideoCapture(video_path) 45 | success,image = vidcap.read() 46 | count = 0 47 | imagepath_list = [] 48 | while success: 49 | imagepath = os.path.join(videofolder, f'{video_name}_frame{count:05d}.jpg') 50 | cv2.imwrite(imagepath, image) # save frame as JPEG file 51 | success,image = vidcap.read() 52 | count += 1 53 | imagepath_list.append(imagepath) 54 | print('video frames are stored in {}'.format(videofolder)) 55 | return imagepath_list 56 | 57 | 58 | from multiprocessing import Pool 59 | from tqdm import tqdm 60 | 61 | def main(): 62 | # Parse command-line arguments 63 | parser = ArgumentParser() 64 | 65 | root = "/gpu-data3/filby/LRS3/pretrain" 66 | 67 | 68 | l = list(os.listdir("/gpu-data3/filby/LRS3/pretrain")) 69 | test_list = [] 70 | for folder in l: 71 | for file in os.listdir(os.path.join("/gpu-data3/filby/LRS3/pretrain",folder)): 72 | 73 | if file.endswith(".txt"): 74 | test_list.append([os.path.join("/gpu-data3/filby/LRS3/pretrain",folder,file.replace(".txt",".mp4")),os.path.join("/gpu-data3/filby/LRS3/pretrain",folder,file.replace(".txt",".mp4"))]) 75 | 76 | # print(test_list[0]) 77 | extract(test_list[0]) 78 | raise 79 | p = Pool(12) 80 | 81 | for _ in tqdm(p.imap_unordered(video2sequence, test_list), total=len(test_list)): 82 | pass 83 | 84 | 85 | main() 86 | 87 | # import os 88 | # import cv2 89 | # import time 90 | # import numpy as np 91 | # import torch 92 | # from argparse import ArgumentParser 93 | # 94 | # import sys 95 | # sys.path.append("face_parsing") 96 | # 97 | # 98 | # def extract_wav(videopath): 99 | # # print(videopath) 100 | # 101 | # os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace("/videos/","/wavs/").replace(".mp4",".wav"))) 102 | # 103 | # from multiprocessing import Pool 104 | # from tqdm import tqdm 105 | # 106 | # def main(): 107 | # # Parse command-line arguments 108 | # parser = ArgumentParser() 109 | # 110 | # root = "/gpu-data3/filby/MEAD/rendered/train/MEAD/videos" 111 | # 112 | # p = Pool(20) 113 | # 114 | # test_list = [] 115 | # for file in os.listdir(root): 116 | # test_list.append(os.path.join(root,file)) 117 | # 118 | # # print(test_list) 119 | # # extract_wav(test_list[0]) 120 | # for _ in tqdm(p.imap_unordered(extract_wav, test_list), total=len(test_list)): 121 | # pass 122 | # 123 | # 124 | # main() -------------------------------------------------------------------------------- /external/spectre/utils/extract_frames_and_audio.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys 4 | import cv2 5 | import argparse 6 | from tqdm import tqdm 7 | from multiprocessing import Pool 8 | 9 | def video2sequence(video_path): 10 | videofolder = os.path.splitext(video_path)[0] 11 | os.makedirs(videofolder, exist_ok=True) 12 | vidcap = cv2.VideoCapture(video_path) 13 | success,image = vidcap.read() 14 | count = 0 15 | imagepath_list = [] 16 | while success: 17 | imagepath = os.path.join(videofolder, f'%06d.jpg'%count) 18 | cv2.imwrite(imagepath, image) # save frame as JPEG file 19 | success,image = vidcap.read() 20 | count += 1 21 | imagepath_list.append(imagepath) 22 | print('video frames are stored in {}'.format(videofolder)) 23 | return videofolder 24 | 25 | 26 | def extract_audio(video_path): 27 | os.system("ffmpeg -i {} {} -y".format(video_path, video_path.replace(".mp4",".wav"))) 28 | 29 | 30 | def main(args): 31 | video_list = [] 32 | 33 | for mode in ["trainval","test"]: 34 | for folder in os.listdir(os.path.join(args.dataset_path,mode)): 35 
| for file in os.listdir(os.path.join(args.dataset_path,mode,folder)): 36 | if file.endswith(".mp4"): 37 | video_list.append(os.path.join(args.dataset_path,mode,folder,file)) 38 | 39 | p = Pool(12) 40 | 41 | for _ in tqdm(p.imap_unordered(video2sequence, video_list), total=len(video_list)): 42 | pass 43 | 44 | for _ in tqdm(p.imap_unordered(extract_audio, video_list), total=len(video_list)): 45 | pass 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | 51 | parser.add_argument('--dataset_path', default='./data/LRS3', type=str, help='path to dataset') 52 | main(parser.parse_args()) -------------------------------------------------------------------------------- /external/spectre/utils/extract_wavs_LRS3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import numpy as np 5 | import torch 6 | from argparse import ArgumentParser 7 | 8 | import sys 9 | sys.path.append("face_parsing") 10 | 11 | 12 | def extract_wav(videopath): 13 | print(videopath) 14 | 15 | os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace(".mp4",".wav"))) 16 | 17 | from multiprocessing import Pool 18 | from tqdm import tqdm 19 | 20 | def main(): 21 | # Parse command-line arguments 22 | parser = ArgumentParser() 23 | 24 | root = "/raid/gretsinas/LRS3/test" 25 | 26 | p = Pool(12) 27 | 28 | l = list(os.listdir("/raid/gretsinas/LRS3/test")) 29 | test_list = [] 30 | for folder in l: 31 | for file in os.listdir(os.path.join("/raid/gretsinas/LRS3/test",folder)): 32 | 33 | if file.endswith(".txt"): 34 | test_list.append(os.path.join("/raid/gretsinas/LRS3/test",folder,file.replace(".txt",".mp4"))) 35 | 36 | # print(test_list) 37 | # extract_wav(test_list[0]) 38 | for _ in tqdm(p.imap_unordered(extract_wav, test_list), total=len(test_list)): 39 | pass 40 | 41 | 42 | main() 43 | 44 | # import os 45 | # import cv2 46 | # import time 47 | # import numpy as np 48 | # import torch 49 | # from argparse import ArgumentParser 50 | # 51 | # import sys 52 | # sys.path.append("face_parsing") 53 | # 54 | # 55 | # def extract_wav(videopath): 56 | # # print(videopath) 57 | # 58 | # os.system("ffmpeg -i {} {} -y".format(videopath, videopath.replace("/videos/","/wavs/").replace(".mp4",".wav"))) 59 | # 60 | # from multiprocessing import Pool 61 | # from tqdm import tqdm 62 | # 63 | # def main(): 64 | # # Parse command-line arguments 65 | # parser = ArgumentParser() 66 | # 67 | # root = "/gpu-data3/filby/MEAD/rendered/train/MEAD/videos" 68 | # 69 | # p = Pool(20) 70 | # 71 | # test_list = [] 72 | # for file in os.listdir(root): 73 | # test_list.append(os.path.join(root,file)) 74 | # 75 | # # print(test_list) 76 | # # extract_wav(test_list[0]) 77 | # for _ in tqdm(p.imap_unordered(extract_wav, test_list), total=len(test_list)): 78 | # pass 79 | # 80 | # 81 | # main() -------------------------------------------------------------------------------- /external/spectre/utils/lipread_utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | from phonemizer.backend import EspeakBackend 7 | from phonemizer.separator import Separator 8 | separator = Separator(phone='-', word=' ') 9 | backend = EspeakBackend('en-us', words_mismatch='ignore', with_stress=False) 10 | import cv2 11 | 12 | # phonemes to visemes map. 
this was created using Amazon Polly 13 | # https://docs.aws.amazon.com/polly/latest/dg/polly-dg.pdf 14 | 15 | def get_phoneme_to_viseme_map(): 16 | pho2vi = {} 17 | # pho2vi_counts = {} 18 | all_vis = [] 19 | 20 | p2v = "data/phonemes2visemes.csv" 21 | 22 | with open(p2v) as file: 23 | lines = file.readlines() 24 | # for line in lines[2:29]+lines[30:50]: 25 | for line in lines: 26 | if line.split(",")[0] in pho2vi: 27 | if line.split(",")[4].strip() != pho2vi[line.split(",")[0]]: 28 | print('error') 29 | pho2vi[line.split(",")[0]] = line.split(",")[4].strip() 30 | 31 | all_vis.append(line.split(",")[4].strip()) 32 | # pho2vi_counts[line.split(",")[0]] = 0 33 | return pho2vi, all_vis 34 | 35 | pho2vi, all_vis = get_phoneme_to_viseme_map() 36 | 37 | def convert_text_to_visemes(text): 38 | phonemized = backend.phonemize([text], separator=separator)[0] 39 | 40 | text = "" 41 | for word in phonemized.split(" "): 42 | visemized = [] 43 | for phoneme in word.split("-"): 44 | if phoneme == "": 45 | continue 46 | try: 47 | visemized.append(pho2vi[phoneme.strip()]) 48 | if pho2vi[phoneme.strip()] not in all_vis: 49 | all_vis.append(pho2vi[phoneme.strip()]) 50 | # pho2vi_counts[phoneme.strip()] += 1 51 | except: 52 | print('Count not find', phoneme) 53 | continue 54 | text += " " + "".join(visemized) 55 | return text 56 | 57 | 58 | 59 | def save2avi(filename, data=None, fps=25): 60 | """save2avi. - function taken from Visual Speech Recognition repository 61 | 62 | :param filename: str, the filename to save the video (.avi). 63 | :param data: numpy.ndarray, the data to be saved. 64 | :param fps: the chosen frames per second. 65 | """ 66 | assert data is not None, "data is {}".format(data) 67 | os.makedirs(os.path.dirname(filename), exist_ok=True) 68 | fourcc = cv2.VideoWriter_fourcc("F", "F", "V", "1") 69 | writer = cv2.VideoWriter(filename, fourcc, fps, (data[0].shape[1], data[0].shape[0]), 0) 70 | for frame in data: 71 | writer.write(frame) 72 | writer.release() 73 | 74 | 75 | def predict_text(lipreader, mouth_sequence): 76 | from external.Visual_Speech_Recognition_for_Multiple_Languages.espnet.asr.asr_utils import add_results_to_json 77 | lipreader.model.eval() 78 | with torch.no_grad(): 79 | enc_feats, _ = lipreader.model.encoder(mouth_sequence, None) 80 | enc_feats = enc_feats.squeeze(0) 81 | 82 | nbest_hyps = lipreader.beam_search( 83 | x=enc_feats, 84 | maxlenratio=lipreader.maxlenratio, 85 | minlenratio=lipreader.minlenratio 86 | ) 87 | nbest_hyps = [ 88 | h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), lipreader.nbest)] 89 | ] 90 | 91 | transcription = add_results_to_json(nbest_hyps, lipreader.char_list) 92 | 93 | return transcription.replace("", "") 94 | 95 | def predict_text_deca(lipreader, mouth_sequence): 96 | from external.Visual_Speech_Recognition_for_Multiple_Languages.espnet.asr.asr_utils import add_results_to_json 97 | lipreader.model.eval() 98 | with torch.no_grad(): 99 | enc_feats, _ = lipreader.model.encoder(mouth_sequence, None) 100 | enc_feats = enc_feats.squeeze(0) 101 | 102 | ys_hat = lipreader.model.ctc.ctc_lo(enc_feats) 103 | # print(ys_hat) 104 | ys_hat = ys_hat.argmax(1) 105 | ys_hat = torch.unique_consecutive(ys_hat, dim=-1) 106 | 107 | ys = [lipreader.model.args.char_list[x] for x in ys_hat if x != 0] 108 | 109 | ys = "".join(ys) 110 | ys = ys.replace("", " ") 111 | 112 | return ys.replace("", "") 113 | 114 | 115 | -------------------------------------------------------------------------------- /external/spectre/utils/run_av_hubert.py: 
-------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from fairseq import checkpoint_utils, tasks, utils 3 | import os 4 | from fairseq.dataclass.configs import GenerationConfig 5 | from utils.lipread_utils import convert_text_to_visemes 6 | from jiwer import wer, cer 7 | 8 | # WARNING 9 | # Run this file with additional command line arguments e.g. python apply_lip_read.py test test due to something stupid by fairseq 10 | 11 | 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | def run_lipreading(videos, transcriptions): 32 | """ 33 | :param videos: list of videos 34 | :param transcriptions: list of transcriptions 35 | :return: 36 | """ 37 | ckpt_path = "../av_hubert/data/self_large_vox_433h.pt" # download this from https://facebookresearch.github.io/av_hubert/ 38 | 39 | utils.import_user_module(Namespace(user_dir='external/av_hubert/avhubert')) 40 | 41 | modalities = ["video"] 42 | gen_subset = "test" 43 | gen_cfg = GenerationConfig(beam=1) 44 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) 45 | models = [model.eval().cuda() for model in models] 46 | saved_cfg.task.modalities = modalities 47 | 48 | import cv2,tempfile 49 | 50 | total_wer = AverageMeter() 51 | total_cer = AverageMeter() 52 | total_werv = AverageMeter() 53 | total_cerv = AverageMeter() 54 | 55 | for idx,video_path in enumerate(videos): 56 | 57 | num_frames = int(cv2.VideoCapture(video_path).get(cv2.CAP_PROP_FRAME_COUNT)) 58 | data_dir = tempfile.mkdtemp() 59 | tsv_cont = ["/\n", f"test-0\t{video_path}\t{None}\t{num_frames}\t{int(16_000*num_frames/25)}\n"] 60 | label_cont = ["DUMMY\n"] 61 | with open(f"{data_dir}/test.tsv", "w") as fo: 62 | fo.write("".join(tsv_cont)) 63 | with open(f"{data_dir}/test.wrd", "w") as fo: 64 | fo.write("".join(label_cont)) 65 | saved_cfg.task.data = data_dir 66 | saved_cfg.task.label_dir = data_dir 67 | task = tasks.setup_task(saved_cfg.task) 68 | task.load_dataset(gen_subset, task_cfg=saved_cfg.task) 69 | generator = task.build_generator(models, gen_cfg) 70 | 71 | def decode_fn(x): 72 | dictionary = task.target_dictionary 73 | symbols_ignore = generator.symbols_to_strip_from_output 74 | symbols_ignore.add(dictionary.pad()) 75 | return task.datasets[gen_subset].label_processors[0].decode(x, symbols_ignore) 76 | 77 | itr = task.get_batch_iterator(dataset=task.dataset(gen_subset)).next_epoch_itr(shuffle=False) 78 | sample = next(itr) 79 | sample = utils.move_to_cuda(sample) 80 | hypos = task.inference_step(generator, models, sample) 81 | hypo = hypos[0][0]['tokens'].int().cpu() 82 | hypo = decode_fn(hypo).upper() 83 | 84 | 85 | groundtruth = transcriptions[idx].upper() 86 | 87 | 88 | w = wer(groundtruth, hypo) 89 | c = cer(groundtruth, hypo) 90 | 91 | 92 | # ---------- convert to visemes -------- # 93 | vg = convert_text_to_visemes(groundtruth) 94 | v = convert_text_to_visemes(hypo) 95 | print(hypo) 96 | print(groundtruth) 97 | print(v) 98 | print(vg) 99 | # -------------------------------------- # 100 | wv = wer(vg, v) 101 | cv = cer(vg, v) 102 | 103 | total_wer.update(w) 104 | total_cer.update(c) 105 | total_werv.update(wv) 106 | total_cerv.update(cv) 
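# Running error rates are tracked both for the raw transcriptions (WER/CER) and for their
# viseme-level versions (WERV/CERV) produced by convert_text_to_visemes; the AverageMeter
# instances accumulate per-video values, so the log line below reports cumulative averages
# over all videos processed so far.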
107 | 108 | print( 109 | f"progress: {idx + 1}/{len(videos)}\tcur WER: {total_wer.val * 100:.1f}\t" 110 | f"cur CER: {total_cer.val * 100:.1f}\t" 111 | f"count: {total_cer.count}\t" 112 | f"avg WER: {total_wer.avg * 100:.1f}\tavg CER: {total_cer.avg * 100:.1f}\t" 113 | f"avg WERV: {total_werv.avg * 100:.1f}\tavg CERV: {total_cerv.avg * 100:.1f}" 114 | ) 115 | 116 | 117 | import glob 118 | if __name__ == "__main__": 119 | import argparse 120 | 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--videos", type=str, required=True, help="path to videos (regex style)") 123 | parser.add_argument("--LRS3_path", type=str, default="/gpu-data3/filby/LRS3", help="path to LRS3") 124 | 125 | args = parser.parse_args() 126 | 127 | video_list = glob.glob(args.videos) 128 | 129 | assert len(video_list) > 0, "No videos found" 130 | 131 | transcriptions = [] 132 | print('Found {} videos'.format(len(video_list))) 133 | 134 | # LRS3 135 | for video in video_list: 136 | video_name = os.path.basename(video).replace("_mouth","") 137 | # print(video_name) 138 | subject = video_name.split("_")[1] 139 | clip = video_name.split("_")[2].replace(".avi", ".txt") 140 | 141 | text = open(os.path.join(f"/{args.LRS3_path}/test/", subject,clip)).readlines()[0].replace("Text:","").strip() 142 | transcriptions.append(text) 143 | 144 | # if running on TCDTIMIT uncomment the following: 145 | # for video in video_list: 146 | # video_name = os.path.basename(video).replace("_mouth", "") 147 | # # print(video_name) 148 | # subject = video_name.split("_")[0] 149 | # clip = video_name.split(".")[0].split("_")[1].upper() + ".txt" 150 | # 151 | # text = open(os.path.join(f"/gpu-data3/filby/EAVTTS/TCDTIMITprocessing/downloadTCDTIMIT/volunteers", subject, 'Clips', 'straightcam', clip)).readlines() 152 | # text = " ".join([x.split()[2].strip() for x in text]) 153 | # 154 | # transcriptions.append(text) 155 | 156 | # if running on MEAD uncomment the following: 157 | # gt = open("data/list_full_mead_annotated.txt").readlines() 158 | # gt_dic = {} 159 | # for line in gt: 160 | # gt_dic[line.split()[0]] = " ".join(line.split()[1:]) 161 | # for video in video_list: 162 | # video_name = os.path.basename(video).replace("_mouth", "") 163 | # # print(video_name) 164 | # subject = video_name.split("_")[0] 165 | # clip = video_name.split(".")[0].split("_")[1].upper() + ".txt" 166 | # 167 | # text = gt_dic[video_name.split(".")[0].replace("_mouth","")] 168 | # 169 | # transcriptions.append(text) 170 | 171 | run_lipreading(video_list, transcriptions) 172 | 173 | -------------------------------------------------------------------------------- /external/spectre/visual_mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import random 5 | import math 6 | import numpy as np 7 | import cv2 8 | import torch 9 | import torchvision 10 | from tqdm import tqdm 11 | from .config import cfg as spectre_cfg 12 | from .src.spectre import SPECTRE 13 | 14 | 15 | class VisualMesh(): 16 | def __init__(self, cfg) -> None: 17 | # model 18 | self.cfg = cfg 19 | self.device = cfg['trainer']['device'] 20 | spectre_cfg.pretrained_modelpath = "external/spectre/pretrained/spectre_model.tar" 21 | spectre_cfg.model.use_tex = False 22 | self.spectre = SPECTRE(spectre_cfg, device=self.device) 23 | self.spectre.eval() 24 | 25 | def forward(self, exp_out, exp, data_info): 26 | 'input: expression coefficients' 27 | 'output: mesh' 28 | n = self.cfg['trainer']['visual_images'] 29 | if 
self.cfg['datasets']['dataset']=='mead': 30 | # template 31 | codedict = {} 32 | codedict['pose'] = torch.zeros((n*3, 6), dtype=torch.float).to(self.device) 33 | codedict['exp'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device) 34 | codedict['shape'] = torch.zeros((n*3, 100), dtype=torch.float).to(self.device) 35 | codedict['tex'] = torch.zeros((n*3, 50), dtype=torch.float).to(self.device) 36 | codedict['cam'] = torch.zeros((n*3, 3), dtype=torch.float).to(self.device) 37 | self.codedict = codedict 38 | # true coefficients 39 | coefficient_path = os.path.join(data_info, 'crop_head_info.npy') 40 | coefficient_info = np.load(coefficient_path, allow_pickle=True).item()['face3d_encode'] 41 | coefficients = get_coefficients(coefficient_info) 42 | for key in coefficients: 43 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device) 44 | start_vis = random.randint(0,exp.shape[1]-1-n) # 起始帧 45 | 46 | self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # 生成的参数在平均脸上 47 | self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground_truth参数在平均脸上 48 | self.codedict['exp'][2*n:3*n] = coefficients['exp'][start_vis: start_vis+n,:] # ground_truth在原脸上 49 | 50 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose 51 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:] 52 | 53 | self.codedict['cam'][0:n] = coefficients['cam'][start_vis: start_vis+n, :] # 取n帧的cam 54 | self.codedict['cam'][n:2*n] = coefficients['cam'][start_vis: start_vis+n, :] 55 | self.codedict['cam'][2*n:3*n] = coefficients['cam'][start_vis: start_vis+n, :] 56 | 57 | self.codedict['pose'][2*n:3*n] = coefficients['pose'][start_vis: start_vis+n, :] # 取n帧的pose 58 | self.codedict['shape'][2*n:3*n] = coefficients['shape'][start_vis: start_vis+n, :] # # 取n帧的shape 59 | 60 | elif self.cfg['datasets']['dataset']=='mote': 61 | # template 62 | codedict = {} 63 | codedict['pose'] = torch.zeros((n*2, 6), dtype=torch.float).to(self.device) 64 | codedict['exp'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device) 65 | codedict['shape'] = torch.zeros((n*2, 100), dtype=torch.float).to(self.device) 66 | codedict['tex'] = torch.zeros((n*2, 50), dtype=torch.float).to(self.device) 67 | codedict['cam'] = torch.zeros((n*2, 3), dtype=torch.float).to(self.device) 68 | self.codedict = codedict 69 | # true coefficients 70 | # coefficient_path = os.path.join(self.cfg['datasets']['data_root'], data_info[0][0], data_info[1][0], 'train1_all.npz') 71 | # coefficient_info = np.load(coefficient_path, allow_pickle=True)['face'][-1*self.cfg['datasets']['eval_frames']:, :] 72 | # coefficients = get_coefficients(coefficient_info) 73 | # for key in coefficients: 74 | # coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device) 75 | start_vis = random.randint(0,exp.shape[1]-1-n) # 起始帧 76 | 77 | self.codedict['exp'][0:n] = exp_out[0, start_vis: start_vis+n,:-3] # 生成的参数在平均脸上 78 | self.codedict['exp'][n:2*n] = exp[0, start_vis: start_vis+n,:-3] # ground_truth参数在平均脸上 79 | 80 | self.codedict['pose'][0:n, 3:] = exp_out[0, start_vis: start_vis+n,-3:] # jaw pose 81 | self.codedict['pose'][n:2*n, 3:] = exp[0, start_vis: start_vis+n,-3:] 82 | 83 | cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1) # cam 84 | self.codedict['cam'][0:n] = cam 85 | self.codedict['cam'][n:2*n] = cam 86 | 87 | opdict = self.spectre.decode(self.codedict, rendering=True, vis_lmk=False, return_vis=False) 88 | # rendered_images = 
torchvision.utils.make_grid(opdict['rendered_images'].detach().cpu(), nrow=n) 89 | return opdict['rendered_images'] 90 | 91 | def infer(self, exp, cfg, exp_gt=None, render_batch=100): 92 | 'input: expression coefficients' 93 | 'output: mesh' 94 | n = exp.shape[1] 95 | if self.cfg['datasets']['dataset']=='mead': 96 | coefficient_path = os.path.join(cfg['datasets']['data_root'], cfg['test']['audio_path']).replace('audio.wav', 'crop_head_info.npy') 97 | coefficient_info = np.load(coefficient_path, allow_pickle=True).item()['face3d_encode'] 98 | coefficients = get_coefficients(coefficient_info) 99 | assert exp.shape[1]==coefficients['exp'].shape[0] 100 | coefficients_pred = copy.deepcopy(coefficients) 101 | for key in coefficients: 102 | coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(self.device) 103 | coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(self.device) 104 | if key == 'exp': 105 | coefficients_pred[key] = exp[0][:, :-3] 106 | elif key == 'pose': 107 | coefficients_pred[key][:, -3:] = exp[0][:, -3:] 108 | elif self.cfg['datasets']['dataset']=='mote': 109 | coefficients = {} 110 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(self.device) 111 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(self.device) 112 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(self.device) 113 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(self.device) 114 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(self.device) 115 | coefficients_pred = copy.deepcopy(coefficients) 116 | # cam = torch.tensor([8.8093824, 0.00314824, 0.043486204]).unsqueeze(0).repeat(n, 1).to(self.device) # cam 117 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(self.device) # cam 118 | for key in coefficients: 119 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(self.device) 120 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(self.device) 121 | if key == 'exp': 122 | if exp_gt is not None: 123 | coefficients[key] = exp_gt[0][:, :-3] 124 | coefficients_pred[key] = exp[0][:, :-3] 125 | elif key == 'pose': 126 | if exp_gt is not None: 127 | coefficients[key][:, -3:] = exp_gt[0][:, -3:] 128 | coefficients_pred[key][:, -3:] = exp[0][:, -3:] 129 | elif key == 'cam': 130 | coefficients[key] = cam[:, :] 131 | coefficients_pred[key] = cam[:, :] 132 | n_batch = int(math.ceil(n/render_batch)) 133 | rendered_images, rendered_images_pred = [], [] 134 | for i in range(n_batch): 135 | coefficients_render, coefficients_pred_render = {}, {} 136 | for k in coefficients: 137 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 138 | coefficients_render[k] = coefficients[k][start_f: end_f] 139 | coefficients_pred_render[k] = coefficients_pred[k][start_f: end_f] 140 | 141 | if exp_gt is not None: 142 | opdict = self.spectre.decode(coefficients_render, rendering=True, vis_lmk=False, return_vis=False) 143 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 144 | opdict_pred = self.spectre.decode(coefficients_pred_render, rendering=True, vis_lmk=False, return_vis=False) 145 | rendered_images_pred.append(opdict_pred['rendered_images'].detach().cpu()) 146 | if exp_gt is not None: 147 | rendered_images_cat = torch.cat(rendered_images, dim=0) 148 | else: 149 | rendered_images_cat = None 150 | rendered_images_pred_cat = torch.cat(rendered_images_pred, dim=0) 
151 | 152 | return rendered_images_cat, rendered_images_pred_cat 153 | # opdict = self.spectre.decode(coefficients, rendering=True, vis_lmk=False, return_vis=False) 154 | # opdict_pred = self.spectre.decode(coefficients_pred, rendering=True, vis_lmk=False, return_vis=False) 155 | # return opdict['rendered_images'], opdict_pred['rendered_images'] 156 | 157 | def exp2mesh(self, coefficients_info, pose0=True, render_batch=100): 158 | n = coefficients_info.shape[0] 159 | if coefficients_info.shape[-1] == 53: 160 | coefficients = {} 161 | coefficients['pose'] = torch.zeros((n, 6), dtype=torch.float).to(self.device) 162 | coefficients['exp'] = torch.zeros((n, 50), dtype=torch.float).to(self.device) 163 | coefficients['shape'] = torch.zeros((n, 100), dtype=torch.float).to(self.device) 164 | coefficients['tex'] = torch.zeros((n, 50), dtype=torch.float).to(self.device) 165 | coefficients['cam'] = torch.zeros((n, 3), dtype=torch.float).to(self.device) 166 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(self.device) # cam 167 | for key in coefficients: 168 | # coefficients[key] = torch.FloatTensor(torch.from_numpy(coefficients[key])).to(self.device) 169 | # coefficients_pred[key] = torch.FloatTensor(torch.from_numpy(coefficients_pred[key])).to(self.device) 170 | if key == 'exp': 171 | coefficients[key] = coefficients_info[:, :-3] 172 | elif key == 'pose': 173 | coefficients[key][:, -3:] = coefficients_info[:, -3:] 174 | elif key == 'cam': 175 | coefficients[key] = cam[:, :] 176 | elif coefficients_info.shape[-1] == 209 or coefficients_info.shape[-1] == 213 or coefficients_info.shape[-1] == 236: 177 | coefficients = get_coefficients(coefficients_info) 178 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(self.device) # cam 179 | for key in coefficients: 180 | coefficients[key] = torch.FloatTensor(coefficients[key]).to(self.device) 181 | if pose0: 182 | if key == 'pose': 183 | coefficients[key][:, :3] = torch.zeros_like(coefficients[key][:, :3]) 184 | elif key == 'shape' or key == 'tex': 185 | coefficients[key] = torch.zeros_like(coefficients[key]) 186 | elif key == 'cam': 187 | coefficients[key] = cam 188 | 189 | n_batch = int(math.ceil(n/render_batch)) 190 | rendered_images = [] 191 | vertices = [] 192 | for i in tqdm(range(n_batch)): 193 | coefficients_batch = {} 194 | for k in coefficients: 195 | start_f, end_f = i*render_batch, min((i+1)*render_batch, n) 196 | coefficients_batch[k] = coefficients[k][start_f: end_f] 197 | opdict = self.spectre.decode(coefficients_batch, rendering=True, vis_lmk=False, return_vis=False) 198 | rendered_images.append(opdict['rendered_images'].detach().cpu()) 199 | vertices.append(opdict['verts'].detach().cpu()) 200 | rendered_images_cat = torch.cat(rendered_images, dim=0) 201 | vertices_cat = torch.cat(vertices, dim=0) 202 | return rendered_images_cat, vertices_cat 203 | 204 | 205 | def get_coefficients(coefficient_info): 206 | coefficient_dict = {} 207 | coefficient_dict['pose'] = coefficient_info[:, :6] 208 | coefficient_dict['exp'] = coefficient_info[:, 6:56] 209 | coefficient_dict['shape'] = coefficient_info[:, 56:156] 210 | coefficient_dict['tex'] = coefficient_info[:, 156:206] 211 | coefficient_dict['cam'] = coefficient_info[:, 206:209] 212 | # coefficient_dict['light'] = coefficient_info[:, 209:236] 213 | return coefficient_dict 214 | -------------------------------------------------------------------------------- /framework.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huifu99/Mimic/9e71299fc041a232e37ac79fbb4dff0b0552c20e/framework.png -------------------------------------------------------------------------------- /models/lib/grl_module.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | import torch 3 | from torch import nn 4 | 5 | class GradientReversal(Function): 6 | @staticmethod 7 | def forward(ctx, x, alpha): 8 | ctx.save_for_backward(x, alpha) 9 | return x 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | grad_input = None 14 | _, alpha = ctx.saved_tensors 15 | if ctx.needs_input_grad[0]: 16 | grad_input = - alpha*grad_output 17 | return grad_input, None 18 | revgrad = GradientReversal.apply 19 | 20 | class GradientReversal(nn.Module): 21 | def __init__(self, alpha): 22 | super().__init__() 23 | self.alpha = torch.tensor(alpha, requires_grad=False) 24 | 25 | def forward(self, x): 26 | return revgrad(x, self.alpha) -------------------------------------------------------------------------------- /models/lib/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | 7 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 8 | ''' Sinusoid position encoding table ''' 9 | 10 | def cal_angle(position, hid_idx): 11 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 12 | 13 | def get_posi_angle_vec(position): 14 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 15 | 16 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) 17 | for pos_i in range(n_position)]) 18 | 19 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 20 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 21 | 22 | if padding_idx is not None: 23 | # zero vector for padding dimension 24 | sinusoid_table[padding_idx] = 0. 
25 | return torch.FloatTensor(sinusoid_table) 26 | 27 | 28 | class Mish(nn.Module): 29 | def __init__(self): 30 | super(Mish, self).__init__() 31 | def forward(self, x): 32 | return x * torch.tanh(F.softplus(x)) 33 | 34 | 35 | class AffineLinear(nn.Module): 36 | def __init__(self, in_dim, out_dim): 37 | super(AffineLinear, self).__init__() 38 | affine = nn.Linear(in_dim, out_dim) 39 | self.affine = affine 40 | 41 | def forward(self, input): 42 | return self.affine(input) 43 | 44 | 45 | class StyleAdaptiveLayerNorm(nn.Module): 46 | def __init__(self, in_channel, style_dim): 47 | super(StyleAdaptiveLayerNorm, self).__init__() 48 | self.in_channel = in_channel 49 | self.norm = nn.LayerNorm(in_channel, elementwise_affine=False) 50 | 51 | self.style = AffineLinear(style_dim, in_channel * 2) 52 | self.style.affine.bias.data[:in_channel] = 1 53 | self.style.affine.bias.data[in_channel:] = 0 54 | 55 | def forward(self, input, style_code): 56 | # style 57 | style = self.style(style_code).unsqueeze(1) 58 | gamma, beta = style.chunk(2, dim=-1) 59 | 60 | out = self.norm(input) 61 | out = gamma * out + beta 62 | return out 63 | 64 | 65 | class LinearNorm(nn.Module): 66 | def __init__(self, 67 | in_channels, 68 | out_channels, 69 | bias=True, 70 | spectral_norm=False, 71 | ): 72 | super(LinearNorm, self).__init__() 73 | self.fc = nn.Linear(in_channels, out_channels, bias) 74 | 75 | if spectral_norm: 76 | self.fc = nn.utils.spectral_norm(self.fc) 77 | 78 | def forward(self, input): 79 | out = self.fc(input) 80 | return out 81 | 82 | 83 | class ConvNorm(nn.Module): 84 | def __init__(self, 85 | in_channels, 86 | out_channels, 87 | kernel_size=1, 88 | stride=1, 89 | padding=None, 90 | dilation=1, 91 | bias=True, 92 | spectral_norm=False, 93 | ): 94 | super(ConvNorm, self).__init__() 95 | 96 | if padding is None: 97 | assert(kernel_size % 2 == 1) 98 | padding = int(dilation * (kernel_size - 1) / 2) 99 | 100 | self.conv = torch.nn.Conv1d(in_channels, 101 | out_channels, 102 | kernel_size=kernel_size, 103 | stride=stride, 104 | padding=padding, 105 | dilation=dilation, 106 | bias=bias) 107 | 108 | if spectral_norm: 109 | self.conv = nn.utils.spectral_norm(self.conv) 110 | 111 | def forward(self, input): 112 | out = self.conv(input) 113 | return out 114 | 115 | 116 | class MultiHeadAttention(nn.Module): 117 | ''' Multi-Head Attention module ''' 118 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False): 119 | super().__init__() 120 | 121 | self.n_head = n_head 122 | self.d_k = d_k 123 | self.d_v = d_v 124 | 125 | self.w_qs = nn.Linear(d_model, n_head * d_k) 126 | self.w_ks = nn.Linear(d_model, n_head * d_k) 127 | self.w_vs = nn.Linear(d_model, n_head * d_v) 128 | 129 | self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) 130 | 131 | self.fc = nn.Linear(n_head * d_v, d_model) 132 | self.dropout = nn.Dropout(dropout) 133 | 134 | if spectral_norm: 135 | self.w_qs = nn.utils.spectral_norm(self.w_qs) 136 | self.w_ks = nn.utils.spectral_norm(self.w_ks) 137 | self.w_vs = nn.utils.spectral_norm(self.w_vs) 138 | self.fc = nn.utils.spectral_norm(self.fc) 139 | 140 | def forward(self, x, mask=None): 141 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 142 | sz_b, len_x, _ = x.size() 143 | 144 | residual = x 145 | 146 | q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) 147 | k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) 148 | v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) 149 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, 150 | len_x, d_k) # (n*b) 
x lq x dk 151 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, 152 | len_x, d_k) # (n*b) x lk x dk 153 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, 154 | len_x, d_v) # (n*b) x lv x dv 155 | 156 | if mask is not None: 157 | slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 158 | else: 159 | slf_mask = None 160 | output, attn = self.attention(q, k, v, mask=slf_mask) 161 | 162 | output = output.view(n_head, sz_b, len_x, d_v) 163 | output = output.permute(1, 2, 0, 3).contiguous().view( 164 | sz_b, len_x, -1) # b x lq x (n*dv) 165 | 166 | output = self.fc(output) 167 | 168 | output = self.dropout(output) + residual 169 | return output, attn 170 | 171 | 172 | class ScaledDotProductAttention(nn.Module): 173 | ''' Scaled Dot-Product Attention ''' 174 | 175 | def __init__(self, temperature, dropout): 176 | super().__init__() 177 | self.temperature = temperature 178 | self.softmax = nn.Softmax(dim=2) 179 | self.dropout = nn.Dropout(dropout) 180 | 181 | def forward(self, q, k, v, mask=None): 182 | 183 | attn = torch.bmm(q, k.transpose(1, 2)) 184 | attn = attn / self.temperature 185 | 186 | if mask is not None: 187 | attn = attn.masked_fill(mask, -np.inf) 188 | 189 | attn = self.softmax(attn) 190 | p_attn = self.dropout(attn) 191 | 192 | output = torch.bmm(p_attn, v) 193 | return output, attn 194 | 195 | 196 | class Conv1dGLU(nn.Module): 197 | ''' 198 | Conv1d + GLU(Gated Linear Unit) with residual connection. 199 | For GLU refer to https://arxiv.org/abs/1612.08083 paper. 200 | ''' 201 | def __init__(self, in_channels, out_channels, kernel_size, dropout): 202 | super(Conv1dGLU, self).__init__() 203 | self.out_channels = out_channels 204 | self.conv1 = ConvNorm(in_channels, 2*out_channels, kernel_size=kernel_size) 205 | self.dropout = nn.Dropout(dropout) 206 | 207 | def forward(self, x): 208 | residual = x 209 | x = self.conv1(x) 210 | x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) 211 | x = x1 * torch.sigmoid(x2) 212 | x = residual + self.dropout(x) 213 | return x 214 | 215 | 216 | class FFTBlock(nn.Module): 217 | ''' FFT Block ''' 218 | def __init__(self, d_model,d_inner, 219 | n_head, d_k, d_v, fft_conv1d_kernel_size, style_dim, dropout): 220 | super(FFTBlock, self).__init__() 221 | self.slf_attn = MultiHeadAttention( 222 | n_head, d_model, d_k, d_v, dropout=dropout) 223 | self.saln_0 = StyleAdaptiveLayerNorm(d_model, style_dim) 224 | 225 | self.pos_ffn = PositionwiseFeedForward( 226 | d_model, d_inner, fft_conv1d_kernel_size, dropout=dropout) 227 | self.saln_1 = StyleAdaptiveLayerNorm(d_model, style_dim) 228 | 229 | def forward(self, input, style_vector, mask=None, slf_attn_mask=None): 230 | # multi-head self attn 231 | slf_attn_output, slf_attn = self.slf_attn(input, mask=slf_attn_mask) 232 | slf_attn_output = self.saln_0(slf_attn_output, style_vector) 233 | if mask is not None: 234 | slf_attn_output = slf_attn_output.masked_fill(mask.unsqueeze(-1), 0) 235 | 236 | # position wise FF 237 | output = self.pos_ffn(slf_attn_output) 238 | output = self.saln_1(output, style_vector) 239 | if mask is not None: 240 | output = output.masked_fill(mask.unsqueeze(-1), 0) 241 | 242 | return output, slf_attn 243 | 244 | 245 | class PositionwiseFeedForward(nn.Module): 246 | ''' A two-feed-forward-layer module ''' 247 | def __init__(self, d_in, d_hid, fft_conv1d_kernel_size, dropout=0.1): 248 | super().__init__() 249 | self.w_1 = ConvNorm(d_in, d_hid, kernel_size=fft_conv1d_kernel_size[0]) 250 | self.w_2 = ConvNorm(d_hid, d_in, kernel_size=fft_conv1d_kernel_size[1]) 251 | 252 | 
self.mish = Mish() 253 | self.dropout = nn.Dropout(dropout) 254 | 255 | def forward(self, input): 256 | residual = input 257 | 258 | output = input.transpose(1, 2) 259 | output = self.w_2(self.dropout(self.mish(self.w_1(output)))) 260 | output = output.transpose(1, 2) 261 | 262 | output = self.dropout(output) + residual 263 | return output -------------------------------------------------------------------------------- /models/lib/wav2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from transformers import Wav2Vec2Model,Wav2Vec2Config 5 | from transformers.modeling_outputs import BaseModelOutput 6 | from typing import Optional, Tuple 7 | _CONFIG_FOR_DOC = "Wav2Vec2Config" 8 | 9 | # the implementation of Wav2Vec2Model is borrowed from https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model 10 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 11 | def _compute_mask_indices( 12 | shape: Tuple[int, int], 13 | mask_prob: float, 14 | mask_length: int, 15 | attention_mask: Optional[torch.Tensor] = None, 16 | min_masks: int = 0, 17 | ) -> np.ndarray: 18 | bsz, all_sz = shape 19 | mask = np.full((bsz, all_sz), False) 20 | 21 | all_num_mask = int( 22 | mask_prob * all_sz / float(mask_length) 23 | + np.random.rand() 24 | ) 25 | all_num_mask = max(min_masks, all_num_mask) 26 | mask_idcs = [] 27 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None 28 | for i in range(bsz): 29 | if padding_mask is not None: 30 | sz = all_sz - padding_mask[i].long().sum().item() 31 | num_mask = int( 32 | mask_prob * sz / float(mask_length) 33 | + np.random.rand() 34 | ) 35 | num_mask = max(min_masks, num_mask) 36 | else: 37 | sz = all_sz 38 | num_mask = all_num_mask 39 | 40 | lengths = np.full(num_mask, mask_length) 41 | 42 | if sum(lengths) == 0: 43 | lengths[0] = min(mask_length, sz - 1) 44 | 45 | min_len = min(lengths) 46 | if sz - min_len <= num_mask: 47 | min_len = sz - num_mask - 1 48 | 49 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 50 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) 51 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 52 | 53 | min_len = min([len(m) for m in mask_idcs]) 54 | for i, mask_idc in enumerate(mask_idcs): 55 | if len(mask_idc) > min_len: 56 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 57 | mask[i, mask_idc] = True 58 | return mask 59 | 60 | # linear interpolation layer 61 | def linear_interpolation(features, input_fps, output_fps, output_len=None): 62 | features = features.transpose(1, 2) 63 | seq_len = features.shape[2] / float(input_fps) 64 | if output_len is None: 65 | output_len = int(seq_len * output_fps)+1 66 | output_features = F.interpolate(features,size=output_len,align_corners=True,mode='linear') 67 | return output_features.transpose(1, 2) 68 | 69 | class Wav2Vec2Model(Wav2Vec2Model): 70 | def __init__(self, config): 71 | super().__init__(config) 72 | def forward( 73 | self, 74 | input_values, 75 | dataset, 76 | attention_mask=None, 77 | output_attentions=None, 78 | output_hidden_states=None, 79 | return_dict=None, 80 | frame_num=None, 81 | align='interpolation', 82 | ): 83 | self.config.output_attentions = True 84 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 85 | output_hidden_states = ( 86 
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 87 | ) 88 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 89 | 90 | hidden_states = self.feature_extractor(input_values) # (B,C,L) 91 | hidden_states = hidden_states.transpose(1, 2) # (B,L,C) L=299 92 | 93 | if dataset == "BIWI": 94 | # cut audio feature 95 | if hidden_states.shape[1]%2 != 0: 96 | hidden_states = hidden_states[:, :-1] 97 | if frame_num and hidden_states.shape[1]>frame_num*2: 98 | hidden_states = hidden_states[:, :frame_num*2] 99 | elif dataset == "vocaset": 100 | # frame_num = int(hidden_states.shape[-2]*cfg.video_fps/cfg.audio_fps) 101 | hidden_states = linear_interpolation(hidden_states, 50, 30, output_len=frame_num) 102 | elif dataset == 'HDTF': 103 | # hidden_states = linear_interpolation(hidden_states, 50, 25, output_len=frame_num) 104 | if align == 'interpolation': 105 | hidden_states = linear_interpolation(hidden_states, 50, 25, output_len=frame_num) 106 | else: 107 | pass 108 | 109 | if attention_mask is not None: 110 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) 111 | attention_mask = torch.zeros( 112 | hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device 113 | ) 114 | attention_mask[ 115 | (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1) 116 | ] = 1 117 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 118 | 119 | hidden_states = self.feature_projection(hidden_states) 120 | if isinstance(hidden_states, tuple): 121 | hidden_states = hidden_states[0] 122 | 123 | if self.config.apply_spec_augment and self.training: 124 | batch_size, sequence_length, hidden_size = hidden_states.size() 125 | if self.config.mask_time_prob > 0: 126 | mask_time_indices = _compute_mask_indices( 127 | (batch_size, sequence_length), 128 | self.config.mask_time_prob, 129 | self.config.mask_time_length, 130 | attention_mask=attention_mask, 131 | min_masks=2, 132 | ) 133 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) 134 | if self.config.mask_feature_prob > 0: 135 | mask_feature_indices = _compute_mask_indices( 136 | (batch_size, hidden_size), 137 | self.config.mask_feature_prob, 138 | self.config.mask_feature_length, 139 | ) 140 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) 141 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 142 | encoder_outputs = self.encoder( 143 | hidden_states, 144 | attention_mask=attention_mask, 145 | output_attentions=output_attentions, 146 | output_hidden_states=output_hidden_states, 147 | return_dict=return_dict, 148 | ) 149 | hidden_states = encoder_outputs[0] 150 | if not return_dict: 151 | return (hidden_states,) + encoder_outputs[1:] 152 | 153 | return BaseModelOutput( 154 | last_hidden_state=hidden_states, 155 | hidden_states=encoder_outputs.hidden_states, 156 | attentions=encoder_outputs.attentions, 157 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av==9.2.0 2 | certifi==2023.11.17 3 | chumpy==0.70 4 | Cython==3.0.8 5 | easydict==1.12 6 | einops==0.7.0 7 | face-alignment==1.4.1 8 | ffmpeg==1.4 9 | ffmpeg-python==0.2.0 10 | future==0.18.3 11 | fvcore==0.1.5.post20221221 12 | gpustat==1.1.1 13 | huggingface-hub==0.0.8 14 | 
imageio==2.33.1 15 | imageio-ffmpeg==0.4.9 16 | iopath==0.1.10 17 | jiwer==2.3.0 18 | keras==2.15.0 19 | kornia==0.6.6 20 | librosa==0.9.2 21 | lmdb==1.4.1 22 | loguru==0.6.0 23 | matplotlib==3.8.3 24 | mediapipe==0.8.10 25 | multiprocess==0.70.15 26 | ninja==1.11.1.1 27 | numba==0.58.1 28 | numpy==1.23.0 29 | omegaconf==2.1.2 30 | opencv-contrib-python==4.9.0.80 31 | opencv-python==4.9.0.80 32 | opencv-python-headless==4.9.0.80 33 | packaging==23.2 34 | pandas==2.2.1 35 | phonemizer==3.2.1 36 | pillow==10.2.0 37 | protobuf==3.20.3 38 | psbody-mesh==0.4 39 | pyglet==2.0.12 40 | PyOpenGL==3.1.0 41 | pyrender==0.1.45 42 | PyYAML==6.0.1 43 | scikit-image==0.19.3 44 | scikit-learn==1.3.2 45 | scipy==1.12.0 46 | six==1.16.0 47 | soundfile==0.12.1 48 | tensorboard==2.15.2 49 | tensorboard-data-server==0.7.2 50 | tensorboard-plugin-wit==1.8.1 51 | tensorflow==2.15.0.post1 52 | tokenizers==0.10.3 53 | torchfile==0.1.0 54 | torchgeometry==0.1.2 55 | torchmetrics==1.3.0 56 | tqdm==4.66.1 57 | transformers==4.6.1 58 | trimesh==3.16.4 59 | yacs==0.1.8 -------------------------------------------------------------------------------- /tools/render_spectre.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | import sys 7 | sys.path.append('.') 8 | from external.spectre.src.utils import util 9 | from external.spectre.src.utils.renderer import SRenderY, set_rasterizer 10 | 11 | 12 | class Render(nn.Module): 13 | def __init__(self, image_size=224, background=True, device='cuda:0') -> None: 14 | super().__init__() 15 | self.image_size = image_size 16 | self.background = background 17 | self.device = device 18 | set_rasterizer('pytorch3d') 19 | obj_filename='external/spectre/data/head_template.obj' 20 | self.render = SRenderY(self.image_size, obj_filename=obj_filename, uv_size=256, rasterizer_type='pytorch3d').to(self.device) 21 | 22 | def forward(self, verts): 23 | if type(verts) is np.ndarray: 24 | verts = torch.FloatTensor(verts).to(self.device) 25 | n = verts.shape[0] 26 | cam = torch.tensor([8.740263, -0.00034628902, 0.020510273]).unsqueeze(0).repeat(n, 1).to(self.device) # need try others 27 | trans_verts = util.batch_orth_proj(verts, cam); 28 | trans_verts[:, :, 1:] = -trans_verts[:, :, 1:] 29 | h, w = self.image_size, self.image_size 30 | if self.background: 31 | background = None 32 | else: # white 33 | background = torch.ones((n, 3, h, w)).to(self.device) 34 | shape_images, _, grid, alpha_images, pos_mask = self.render.render_shape(verts, trans_verts, h=h, w=w, 35 | images=background, 36 | return_grid=True, 37 | return_pos=True) 38 | rendered_images = self.postprocess(shape_images) 39 | return rendered_images 40 | 41 | def postprocess(self, rendered_images): 42 | # rendered_images (b, c, h, w) 43 | rendered_images = rendered_images.cpu().numpy() 44 | rendered_images = rendered_images*255. 
45 | rendered_images = np.maximum(np.minimum(rendered_images, 255), 0) 46 | rendered_images = rendered_images.transpose(0,2,3,1)[:,:,:,[2,1,0]] 47 | rendered_images = rendered_images.astype(np.uint8).copy() 48 | return rendered_images 49 | 50 | 51 | if __name__ == '__main__': 52 | import cv2 53 | render = Render() 54 | verts_path = '/root/autodl-tmp/data/fh/VOCASET/FaceFormer_processed/vertices_npy/FaceTalk_170725_00137_TA_sentence01.npy' 55 | save_root = os.path.join('tmp', 'FaceTalk_170725_00137_TA_sentence01') 56 | os.makedirs(save_root, exist_ok=True) 57 | verts = np.load(verts_path,allow_pickle=True)[::2,:] 58 | 59 | # verts_path = '/root/autodl-tmp/data/fh/HDTF/spectre_processed_25fps_16kHz/RD_Radio1_000/verts_new_shape1.npy' 60 | # save_root = os.path.join('tmp', 'RD_Radio1_000') 61 | # os.makedirs(save_root, exist_ok=True) 62 | # verts = np.load(verts_path, allow_pickle=True).item()['verts'] 63 | verts = torch.FloatTensor(verts).to('cuda:0') 64 | for i in range(verts.shape[0]): 65 | vert = verts[i:i+1].view(-1, 5023, 3) 66 | render_img = render.forward(vert) 67 | img_save = os.path.join(save_root, str(i+1).zfill(6)+'.png') 68 | cv2.imwrite(img_save, render_img[0]) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | # set device 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # single GPU for training 4 | # os.environ['TRANSFORMERS_OFFLINE'] = '1' 5 | import datetime 6 | import numpy as np 7 | import copy 8 | import cv2 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.tensorboard import SummaryWriter 12 | from base.utilities import get_logger, get_parser, fixed_seed 13 | "model" 14 | from models.network import DisNetAutoregCycle as Model 15 | "data" 16 | from data.dataloader_HDTF import get_dataloaders 17 | "loss" 18 | from losses.loss_collections import ComposeCycleLoss as ComposeLoss 19 | "val" 20 | # from val import Val, ValTrainset 21 | 22 | 23 | def main(): 24 | fixed_seed(seed=42) 25 | cfg, params_dict = get_parser() 26 | device = 'cuda:0' 27 | output_dir = os.path.join(cfg.output, cfg.dataset+'_train{}_val-shape1'.format(str(cfg.train_ids)), cfg.exp_name) 28 | 29 | "log" 30 | logger = get_logger() 31 | log_dir = os.path.join(output_dir, 'log') 32 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 33 | if cfg.continue_ckpt is None: 34 | writer = SummaryWriter(f'{log_dir}/runs/{timestamp}') 35 | 36 | "output" 37 | save_dir = os.path.join(output_dir, 'ckpt') 38 | os.makedirs(save_dir, exist_ok=True) 39 | 40 | "model" 41 | model = Model(cfg) 42 | if cfg.continue_ckpt is None: 43 | model = model.to(device) 44 | 45 | "data" 46 | dataset = get_dataloaders(cfg) 47 | train_loader = dataset['train'] 48 | # val_loader = dataset['valid'] 49 | 50 | "val" 51 | # val = Val(cfg, device) 52 | # val_trainset = ValTrainset(cfg, device) 53 | 54 | "optimizer" 55 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=float(cfg.base_lr), betas=(0.9, 0.999)) 56 | 57 | "lr_scheduler" 58 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.lr_sch_gamma) 59 | 60 | "loss" 61 | criterion = ComposeLoss(cfg).to(device) 62 | 63 | iteration = 0 64 | "continue train" 65 | if cfg.continue_ckpt is not None: 66 | ckpt_dict = torch.load(cfg.continue_ckpt, map_location='cpu') 67 | cfg.start_epoch = ckpt_dict['start_epoch'] 68 | weights = ckpt_dict['model'] 69 | model.load_state_dict(weights) 
70 | model.to(device) 71 | optim = ckpt_dict['optimizer'] 72 | # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=ckpt_dict['optimizer']['param_groups'][0]['lr'], betas=(0.9, 0.999)) 73 | optimizer.load_state_dict(optim) 74 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.lr_sch_gamma, last_epoch=cfg.start_epoch-1) 75 | timestamp = '2023-07-10-23-27-58' 76 | writer = SummaryWriter(f'{log_dir}/runs/{timestamp}') 77 | iteration = cfg.start_epoch*len(train_loader) 78 | 79 | "train" 80 | for epoch in range(cfg.start_epoch, cfg.epochs): 81 | 82 | model.train() 83 | optimizer.zero_grad() 84 | if cfg.lr_sch_epoch is not None and (epoch+1) % cfg.lr_sch_epoch == 0: 85 | scheduler.step() 86 | loss_train_list = [] 87 | for b, data in enumerate(train_loader): 88 | audio, vertices, template, one_hot, subject_id, init_state = data['audio'], data['vertices'], data['template'], data['one_hot'], data['subject_id'], data['init_state'] 89 | vertices = vertices.to(device) 90 | label = vertices 91 | audio = audio.to(device) 92 | # text_label = text_label.to(device) 93 | template = template.to(device) 94 | one_hot = one_hot.to(device) 95 | id_label = torch.argmax(one_hot, dim=1) 96 | init_state = init_state.to(device) 97 | output = model(vertices, audio, template, init_state, id_label=id_label) 98 | # output = model(vertices, audio, template, one_hot) 99 | loss_dict = criterion(output, label, text_label=None, id_label=id_label) 100 | loss = loss_dict['loss'] 101 | loss.backward() 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | loss_train_list.append(loss.item()) 105 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 106 | log = "(Epoch {}/{}, Batch {}/{}), lr {}, TRAIN LOSS:{:.9f}".format((epoch+1), cfg.epochs, b+1, len(train_loader), lr, loss) 107 | logger.info(log) 108 | iteration += 1 109 | if iteration%cfg.print_freq==0: 110 | for loss_item in loss_dict: 111 | if 'acc' not in loss_item: 112 | writer.add_scalar('train/{}'.format(loss_item), loss_dict[loss_item].item(), iteration) 113 | # debug 114 | # if loss_dict['content_contrastive_loss'].item() > 4.5e-6: 115 | # print(subject_id) 116 | # cal loss for epoch 117 | if b==0: 118 | loss_dict_epoch = {} 119 | for key in loss_dict: 120 | loss_dict_epoch[key] = loss_dict[key] 121 | else: 122 | for key in loss_dict_epoch: 123 | loss_dict_epoch[key] += loss_dict[key] 124 | 125 | for loss_item in loss_dict_epoch: 126 | writer.add_scalar('train/{}_epoch'.format(loss_item), loss_dict_epoch[loss_item].item()/(b+1), epoch+1) 127 | # writer.add_scalar('train/loss_epoch', np.mean(loss_train_list), epoch+1) 128 | 129 | if (epoch+1)%cfg.save_freq==0: 130 | state = {'params': params_dict['params'], 131 | 'model': model.state_dict(), 132 | 'optimizer': optimizer.state_dict(), 133 | 'start_epoch': epoch+1, 134 | } 135 | ckpt_dir = os.path.join(save_dir, 'Epoch_{}.pth'.format(epoch+1)) 136 | torch.save(state, ckpt_dir) 137 | 138 | # if cfg.content_grl_loss.w_decay is not None and (epoch+1)%cfg.content_grl_loss.w_decay == 0: 139 | # cfg.content_grl_loss.w = cfg.content_grl_loss.w*0.1 140 | # debug 141 | if cfg.content_grl_loss.w_decay is not None: 142 | if epoch+1 == 20: 143 | cfg.content_grl_loss.w = cfg.content_grl_loss.w*0.1 144 | elif epoch+1 == 40: 145 | cfg.content_grl_loss.w = cfg.content_grl_loss.w*0.5 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | main() -------------------------------------------------------------------------------- /utils/render_pyrender.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import cv2 4 | import numpy as np 5 | 6 | cv2.ocl.setUseOpenCL(False) 7 | cv2.setNumThreads(0) 8 | 9 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa' #egl 10 | import pyrender 11 | import trimesh 12 | from psbody.mesh import Mesh 13 | 14 | 15 | # The implementation of rendering is borrowed from VOCA: https://github.com/TimoBolkart/voca/blob/master/utils/rendering.py 16 | def render_mesh_helper(args,mesh, t_center, rot=np.zeros(3), tex_img=None, z_offset=0): 17 | if args.dataset == "BIWI": 18 | camera_params = {'c': np.array([400, 400]), 19 | 'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]), 20 | 'f': np.array([4754.97941935 / 8, 4754.97941935 / 8])} 21 | elif args.dataset == "vocaset" or args.dataset == 'HDTF': 22 | camera_params = {'c': np.array([400, 400]), 23 | 'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]), 24 | 'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])} 25 | 26 | frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800} 27 | # if args.dataset == 'HDTF': 28 | # frustum = {'near': 0.01, 'far': 3.0, 'height': 224, 'width': 224} 29 | 30 | mesh_copy = Mesh(mesh.v, mesh.f) 31 | mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot((mesh_copy.v-t_center).T).T+t_center 32 | 33 | intensity = 2.0 34 | 35 | primitive_material = pyrender.material.MetallicRoughnessMaterial( 36 | alphaMode='BLEND', 37 | baseColorFactor=[0.3, 0.3, 0.3, 1.0], 38 | metallicFactor=0.8, 39 | roughnessFactor=0.8 40 | ) 41 | 42 | 43 | tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f) 44 | render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=primitive_material,smooth=True) 45 | 46 | if args.background_black: 47 | scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0]) 48 | else: 49 | scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255]) 50 | 51 | camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0], 52 | fy=camera_params['f'][1], 53 | cx=camera_params['c'][0], 54 | cy=camera_params['c'][1], 55 | znear=frustum['near'], 56 | zfar=frustum['far']) 57 | 58 | scene.add(render_mesh, pose=np.eye(4)) 59 | 60 | camera_pose = np.eye(4) 61 | camera_pose[:3,3] = np.array([0, 0, 1.0-z_offset]) 62 | scene.add(camera, pose=[[1, 0, 0, 0], 63 | [0, 1, 0, 0], 64 | [0, 0, 1, 1], 65 | [0, 0, 0, 1]]) 66 | 67 | angle = np.pi / 6.0 68 | pos = camera_pose[:3,3] 69 | light_color = np.array([1., 1., 1.]) 70 | light = pyrender.DirectionalLight(color=light_color, intensity=intensity) 71 | 72 | light_pose = np.eye(4) 73 | light_pose[:3,3] = pos 74 | scene.add(light, pose=light_pose.copy()) 75 | 76 | light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos) 77 | scene.add(light, pose=light_pose.copy()) 78 | 79 | light_pose[:3,3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos) 80 | scene.add(light, pose=light_pose.copy()) 81 | 82 | light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos) 83 | scene.add(light, pose=light_pose.copy()) 84 | 85 | light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos) 86 | scene.add(light, pose=light_pose.copy()) 87 | 88 | flags = pyrender.RenderFlags.SKIP_CULL_FACES 89 | try: 90 | r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height']) 91 | color, _ = r.render(scene, flags=flags) 92 | except: 93 | print('pyrender: Failed rendering frame') 94 | color = np.zeros((frustum['height'], frustum['width'], 3), dtype='uint8') 95 | 96 | return 
color[..., ::-1] --------------------------------------------------------------------------------
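For reference, a minimal sketch of how render_mesh_helper from utils/render_pyrender.py above might be driven to write per-frame renders to disk. The vertex array path ('demo_verts.npy'), the Namespace fields, and the FLAME template location are illustrative assumptions, not values taken from this repository; the function itself returns an 800x800 BGR uint8 image suitable for cv2.imwrite.

import os
import numpy as np
import cv2
import trimesh
from argparse import Namespace
from psbody.mesh import Mesh
from utils.render_pyrender import render_mesh_helper

# only these two fields are read inside render_mesh_helper
args = Namespace(dataset='HDTF', background_black=True)
# FLAME face topology; template path assumed, same asset used by tools/render_spectre.py
faces = trimesh.load('external/spectre/data/head_template.obj', process=False).faces
# hypothetical (T, 5023, 3) vertex sequence produced elsewhere in the pipeline
verts_seq = np.load('demo_verts.npy')
# rotate/center the mesh around the first frame's centroid
t_center = np.mean(verts_seq[0], axis=0)

os.makedirs('tmp/render', exist_ok=True)
for i, v in enumerate(verts_seq):
    frame = render_mesh_helper(args, Mesh(v, faces), t_center)  # BGR uint8, 800x800
    cv2.imwrite(os.path.join('tmp/render', '%06d.png' % (i + 1)), np.ascontiguousarray(frame))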