├── .gitignore ├── LICENSE ├── activation.py ├── data_utils ├── deepspeech_features │ ├── README.md │ ├── deepspeech_features.py │ ├── deepspeech_store.py │ ├── extract_ds_features.py │ ├── extract_wav.py │ └── fea_win.py ├── face_parsing │ ├── logger.py │ ├── model.py │ ├── resnet.py │ └── test.py ├── face_tracking │ ├── __init__.py │ ├── convert_BFM.py │ ├── data_loader.py │ ├── face_tracker.py │ ├── facemodel.py │ ├── geo_transform.py │ ├── render_3dmm.py │ ├── render_land.py │ └── util.py └── process.py ├── encoding.py ├── freqencoder ├── __init__.py ├── backend.py ├── freq.py ├── setup.py └── src │ ├── bindings.cpp │ ├── freqencoder.cu │ └── freqencoder.h ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── main.py ├── nerf ├── asr.py ├── gui.py ├── network.py ├── provider.py ├── renderer.py └── utils.py ├── raymarching ├── __init__.py ├── backend.py ├── raymarching.py ├── setup.py └── src │ ├── bindings.cpp │ ├── raymarching.cu │ └── raymarching.h ├── readme.md ├── requirements.txt ├── scripts ├── install_ext.sh ├── test_pretrained.sh ├── test_streaming.sh ├── train_obama_ds.sh └── train_obama_eo.sh ├── shencoder ├── __init__.py ├── backend.py ├── setup.py ├── sphere_harmonics.py └── src │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | *.mp4 6 | 7 | tmp* 8 | trial*/ 9 | 10 | data 11 | data_utils/face_tracking/3DMM/* 12 | data_utils/face_parsing/79999_iter.pth 13 | 14 | pretrained 15 | *.mp4 16 | 17 | scripts/train_chris_eo.sh 18 | scripts/train_marco_eo.sh 19 | scripts/test.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | class _trunc_exp(Function): 6 | @staticmethod 7 | @custom_fwd(cast_inputs=torch.float32) # cast to float32 8 | def forward(ctx, x): 9 | ctx.save_for_backward(x) 10 | return torch.exp(x) 11 | 12 | @staticmethod 13 | @custom_bwd 14 | def backward(ctx, g): 15 | x = ctx.saved_tensors[0] 16 | return g * torch.exp(x.clamp(-15, 15)) 17 | 18 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /data_utils/deepspeech_features/README.md: -------------------------------------------------------------------------------- 1 | # Routines for DeepSpeech features processing 2 | Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model. 3 | 4 | ## Installation 5 | 6 | ``` 7 | pip3 install -r requirements.txt 8 | ``` 9 | 10 | ## Usage 11 | 12 | Generate wav files: 13 | ``` 14 | python3 extract_wav.py --in-video= 15 | ``` 16 | 17 | Generate files with DeepSpeech features: 18 | ``` 19 | python3 extract_ds_features.py --input= 20 | ``` 21 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/deepspeech_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepSpeech features processing routines. 3 | NB: Based on VOCA code. See the corresponding license restrictions. 4 | """ 5 | 6 | __all__ = ['conv_audios_to_deepspeech'] 7 | 8 | import numpy as np 9 | import warnings 10 | import resampy 11 | from scipy.io import wavfile 12 | from python_speech_features import mfcc 13 | import tensorflow as tf 14 | 15 | 16 | def conv_audios_to_deepspeech(audios, 17 | out_files, 18 | num_frames_info, 19 | deepspeech_pb_path, 20 | audio_window_size=1, 21 | audio_window_stride=1): 22 | """ 23 | Convert list of audio files into files with DeepSpeech features. 24 | 25 | Parameters 26 | ---------- 27 | audios : list of str or list of None 28 | Paths to input audio files. 29 | out_files : list of str 30 | Paths to output files with DeepSpeech features. 31 | num_frames_info : list of int 32 | List of numbers of frames. 33 | deepspeech_pb_path : str 34 | Path to DeepSpeech 0.1.0 frozen model. 35 | audio_window_size : int, default 16 36 | Audio window size. 37 | audio_window_stride : int, default 1 38 | Audio window stride. 
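    Examples
    --------
    A minimal call sketch (the .wav/.npy paths are placeholders, not files
    shipped with this repo); passing ``None`` in ``num_frames_info`` lets the
    frame count be inferred from the audio length:

    >>> conv_audios_to_deepspeech(
    ...     audios=["aud.wav"],
    ...     out_files=["aud_ds.npy"],
    ...     num_frames_info=[None],
    ...     deepspeech_pb_path="deepspeech-0_1_0-b90017e8.pb")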
39 | """ 40 | # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" 41 | graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net( 42 | deepspeech_pb_path) 43 | 44 | with tf.compat.v1.Session(graph=graph) as sess: 45 | for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info): 46 | print(audio_file_path) 47 | print(out_file_path) 48 | audio_sample_rate, audio = wavfile.read(audio_file_path) 49 | if audio.ndim != 1: 50 | warnings.warn( 51 | "Audio has multiple channels, the first channel is used") 52 | audio = audio[:, 0] 53 | ds_features = pure_conv_audio_to_deepspeech( 54 | audio=audio, 55 | audio_sample_rate=audio_sample_rate, 56 | audio_window_size=audio_window_size, 57 | audio_window_stride=audio_window_stride, 58 | num_frames=num_frames, 59 | net_fn=lambda x: sess.run( 60 | logits_ph, 61 | feed_dict={ 62 | input_node_ph: x[np.newaxis, ...], 63 | input_lengths_ph: [x.shape[0]]})) 64 | 65 | net_output = ds_features.reshape(-1, 29) 66 | win_size = 16 67 | zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) 68 | net_output = np.concatenate( 69 | (zero_pad, net_output, zero_pad), axis=0) 70 | windows = [] 71 | for window_index in range(0, net_output.shape[0] - win_size, 2): 72 | windows.append( 73 | net_output[window_index:window_index + win_size]) 74 | print(np.array(windows).shape) 75 | np.save(out_file_path, np.array(windows)) 76 | 77 | 78 | def prepare_deepspeech_net(deepspeech_pb_path): 79 | """ 80 | Load and prepare DeepSpeech network. 81 | 82 | Parameters 83 | ---------- 84 | deepspeech_pb_path : str 85 | Path to DeepSpeech 0.1.0 frozen model. 86 | 87 | Returns 88 | ------- 89 | graph : obj 90 | ThensorFlow graph. 91 | logits_ph : obj 92 | ThensorFlow placeholder for `logits`. 93 | input_node_ph : obj 94 | ThensorFlow placeholder for `input_node`. 95 | input_lengths_ph : obj 96 | ThensorFlow placeholder for `input_lengths`. 97 | """ 98 | # Load graph and place_holders: 99 | with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: 100 | graph_def = tf.compat.v1.GraphDef() 101 | graph_def.ParseFromString(f.read()) 102 | 103 | graph = tf.compat.v1.get_default_graph() 104 | tf.import_graph_def(graph_def, name="deepspeech") 105 | logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") 106 | input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") 107 | input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") 108 | 109 | return graph, logits_ph, input_node_ph, input_lengths_ph 110 | 111 | 112 | def pure_conv_audio_to_deepspeech(audio, 113 | audio_sample_rate, 114 | audio_window_size, 115 | audio_window_stride, 116 | num_frames, 117 | net_fn): 118 | """ 119 | Core routine for converting audion into DeepSpeech features. 120 | 121 | Parameters 122 | ---------- 123 | audio : np.array 124 | Audio data. 125 | audio_sample_rate : int 126 | Audio sample rate. 127 | audio_window_size : int 128 | Audio window size. 129 | audio_window_stride : int 130 | Audio window stride. 131 | num_frames : int or None 132 | Numbers of frames. 133 | net_fn : func 134 | Function for DeepSpeech model call. 135 | 136 | Returns 137 | ------- 138 | np.array 139 | DeepSpeech features. 
140 | """ 141 | target_sample_rate = 16000 142 | if audio_sample_rate != target_sample_rate: 143 | resampled_audio = resampy.resample( 144 | x=audio.astype(np.float), 145 | sr_orig=audio_sample_rate, 146 | sr_new=target_sample_rate) 147 | else: 148 | resampled_audio = audio.astype(np.float) 149 | input_vector = conv_audio_to_deepspeech_input_vector( 150 | audio=resampled_audio.astype(np.int16), 151 | sample_rate=target_sample_rate, 152 | num_cepstrum=26, 153 | num_context=9) 154 | 155 | network_output = net_fn(input_vector) 156 | # print(network_output.shape) 157 | 158 | deepspeech_fps = 50 159 | video_fps = 50 # Change this option if video fps is different 160 | audio_len_s = float(audio.shape[0]) / audio_sample_rate 161 | if num_frames is None: 162 | num_frames = int(round(audio_len_s * video_fps)) 163 | else: 164 | video_fps = num_frames / audio_len_s 165 | network_output = interpolate_features( 166 | features=network_output[:, 0], 167 | input_rate=deepspeech_fps, 168 | output_rate=video_fps, 169 | output_len=num_frames) 170 | 171 | # Make windows: 172 | zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) 173 | network_output = np.concatenate( 174 | (zero_pad, network_output, zero_pad), axis=0) 175 | windows = [] 176 | for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): 177 | windows.append( 178 | network_output[window_index:window_index + audio_window_size]) 179 | 180 | return np.array(windows) 181 | 182 | 183 | def conv_audio_to_deepspeech_input_vector(audio, 184 | sample_rate, 185 | num_cepstrum, 186 | num_context): 187 | """ 188 | Convert audio raw data into DeepSpeech input vector. 189 | 190 | Parameters 191 | ---------- 192 | audio : np.array 193 | Audio data. 194 | audio_sample_rate : int 195 | Audio sample rate. 196 | num_cepstrum : int 197 | Number of cepstrum. 198 | num_context : int 199 | Number of context. 200 | 201 | Returns 202 | ------- 203 | np.array 204 | DeepSpeech input vector. 205 | """ 206 | # Get mfcc coefficients: 207 | features = mfcc( 208 | signal=audio, 209 | samplerate=sample_rate, 210 | numcep=num_cepstrum) 211 | 212 | # We only keep every second feature (BiRNN stride = 2): 213 | features = features[::2] 214 | 215 | # One stride per time step in the input: 216 | num_strides = len(features) 217 | 218 | # Add empty initial and final contexts: 219 | empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) 220 | features = np.concatenate((empty_context, features, empty_context)) 221 | 222 | # Create a view into the array with overlapping strides of size 223 | # numcontext (past) + 1 (present) + numcontext (future): 224 | window_size = 2 * num_context + 1 225 | train_inputs = np.lib.stride_tricks.as_strided( 226 | features, 227 | shape=(num_strides, window_size, num_cepstrum), 228 | strides=(features.strides[0], 229 | features.strides[0], features.strides[1]), 230 | writeable=False) 231 | 232 | # Flatten the second and third dimensions: 233 | train_inputs = np.reshape(train_inputs, [num_strides, -1]) 234 | 235 | train_inputs = np.copy(train_inputs) 236 | train_inputs = (train_inputs - np.mean(train_inputs)) / \ 237 | np.std(train_inputs) 238 | 239 | return train_inputs 240 | 241 | 242 | def interpolate_features(features, 243 | input_rate, 244 | output_rate, 245 | output_len): 246 | """ 247 | Interpolate DeepSpeech features. 248 | 249 | Parameters 250 | ---------- 251 | features : np.array 252 | DeepSpeech features. 253 | input_rate : int 254 | input rate (FPS). 
255 | output_rate : int 256 | Output rate (FPS). 257 | output_len : int 258 | Output data length. 259 | 260 | Returns 261 | ------- 262 | np.array 263 | Interpolated data. 264 | """ 265 | input_len = features.shape[0] 266 | num_features = features.shape[1] 267 | input_timestamps = np.arange(input_len) / float(input_rate) 268 | output_timestamps = np.arange(output_len) / float(output_rate) 269 | output_features = np.zeros((output_len, num_features)) 270 | for feature_idx in range(num_features): 271 | output_features[:, feature_idx] = np.interp( 272 | x=output_timestamps, 273 | xp=input_timestamps, 274 | fp=features[:, feature_idx]) 275 | return output_features 276 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/deepspeech_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | Routines for loading DeepSpeech model. 3 | """ 4 | 5 | __all__ = ['get_deepspeech_model_file'] 6 | 7 | import os 8 | import zipfile 9 | import logging 10 | import hashlib 11 | 12 | 13 | deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' 14 | 15 | 16 | def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): 17 | """ 18 | Return location for the pretrained on local file system. This function will download from online model zoo when 19 | model cannot be found or has mismatch. The root directory will be created if it doesn't exist. 20 | 21 | Parameters 22 | ---------- 23 | local_model_store_dir_path : str, default $TENSORFLOW_HOME/models 24 | Location for keeping the model parameters. 25 | 26 | Returns 27 | ------- 28 | file_path 29 | Path to the requested pretrained model file. 30 | """ 31 | sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" 32 | file_name = "deepspeech-0_1_0-b90017e8.pb" 33 | local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) 34 | file_path = os.path.join(local_model_store_dir_path, file_name) 35 | if os.path.exists(file_path): 36 | if _check_sha1(file_path, sha1_hash): 37 | return file_path 38 | else: 39 | logging.warning("Mismatch in the content of model file detected. Downloading again.") 40 | else: 41 | logging.info("Model file not found. Downloading to {}.".format(file_path)) 42 | 43 | if not os.path.exists(local_model_store_dir_path): 44 | os.makedirs(local_model_store_dir_path) 45 | 46 | zip_file_path = file_path + ".zip" 47 | _download( 48 | url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( 49 | repo_url=deepspeech_features_repo_url, 50 | repo_release_tag="v0.0.1", 51 | file_name=file_name), 52 | path=zip_file_path, 53 | overwrite=True) 54 | with zipfile.ZipFile(zip_file_path) as zf: 55 | zf.extractall(local_model_store_dir_path) 56 | os.remove(zip_file_path) 57 | 58 | if _check_sha1(file_path, sha1_hash): 59 | return file_path 60 | else: 61 | raise ValueError("Downloaded file has different hash. Please try again.") 62 | 63 | 64 | def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): 65 | """ 66 | Download an given URL 67 | 68 | Parameters 69 | ---------- 70 | url : str 71 | URL to download 72 | path : str, optional 73 | Destination path to store downloaded file. By default stores to the 74 | current directory with same name as in url. 75 | overwrite : bool, optional 76 | Whether to overwrite destination file if already exists. 77 | sha1_hash : str, optional 78 | Expected sha1 hash in hexadecimal digits. 
Will ignore existing file when hash is specified 79 | but doesn't match. 80 | retries : integer, default 5 81 | The number of times to attempt the download in case of failure or non 200 return codes 82 | verify_ssl : bool, default True 83 | Verify SSL certificates. 84 | 85 | Returns 86 | ------- 87 | str 88 | The file path of the downloaded file. 89 | """ 90 | import warnings 91 | try: 92 | import requests 93 | except ImportError: 94 | class requests_failed_to_import(object): 95 | pass 96 | requests = requests_failed_to_import 97 | 98 | if path is None: 99 | fname = url.split("/")[-1] 100 | # Empty filenames are invalid 101 | assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." 102 | else: 103 | path = os.path.expanduser(path) 104 | if os.path.isdir(path): 105 | fname = os.path.join(path, url.split("/")[-1]) 106 | else: 107 | fname = path 108 | assert retries >= 0, "Number of retries should be at least 0" 109 | 110 | if not verify_ssl: 111 | warnings.warn( 112 | "Unverified HTTPS request is being made (verify_ssl=False). " 113 | "Adding certificate verification is strongly advised.") 114 | 115 | if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): 116 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 117 | if not os.path.exists(dirname): 118 | os.makedirs(dirname) 119 | while retries + 1 > 0: 120 | # Disable pyling too broad Exception 121 | # pylint: disable=W0703 122 | try: 123 | print("Downloading {} from {}...".format(fname, url)) 124 | r = requests.get(url, stream=True, verify=verify_ssl) 125 | if r.status_code != 200: 126 | raise RuntimeError("Failed downloading url {}".format(url)) 127 | with open(fname, "wb") as f: 128 | for chunk in r.iter_content(chunk_size=1024): 129 | if chunk: # filter out keep-alive new chunks 130 | f.write(chunk) 131 | if sha1_hash and not _check_sha1(fname, sha1_hash): 132 | raise UserWarning("File {} is downloaded but the content hash does not match." 133 | " The repo may be outdated or download may be incomplete. " 134 | "If the `repo_url` is overridden, consider switching to " 135 | "the default repo.".format(fname)) 136 | break 137 | except Exception as e: 138 | retries -= 1 139 | if retries <= 0: 140 | raise e 141 | else: 142 | print("download failed, retrying, {} attempt{} left" 143 | .format(retries, "s" if retries > 1 else "")) 144 | 145 | return fname 146 | 147 | 148 | def _check_sha1(filename, sha1_hash): 149 | """ 150 | Check whether the sha1 hash of the file content matches the expected hash. 151 | 152 | Parameters 153 | ---------- 154 | filename : str 155 | Path to the file. 156 | sha1_hash : str 157 | Expected sha1 hash in hexadecimal digits. 158 | 159 | Returns 160 | ------- 161 | bool 162 | Whether the file content matches the expected hash. 163 | """ 164 | sha1 = hashlib.sha1() 165 | with open(filename, "rb") as f: 166 | while True: 167 | data = f.read(1048576) 168 | if not data: 169 | break 170 | sha1.update(data) 171 | 172 | return sha1.hexdigest() == sha1_hash 173 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/extract_ds_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for extracting DeepSpeech features from audio file. 
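Typical invocation (the path below is a placeholder; it may point either at a
single .wav file or at a directory of .wav files):

    python3 extract_ds_features.py --input=path/to/audio_or_dir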
3 | """ 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | import pandas as pd 9 | from deepspeech_store import get_deepspeech_model_file 10 | from deepspeech_features import conv_audios_to_deepspeech 11 | 12 | 13 | def parse_args(): 14 | """ 15 | Create python script parameters. 16 | Returns 17 | ------- 18 | ArgumentParser 19 | Resulted args. 20 | """ 21 | parser = argparse.ArgumentParser( 22 | description="Extract DeepSpeech features from audio file", 23 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="path to input audio file or directory") 29 | parser.add_argument( 30 | "--output", 31 | type=str, 32 | help="path to output file with DeepSpeech features") 33 | parser.add_argument( 34 | "--deepspeech", 35 | type=str, 36 | help="path to DeepSpeech 0.1.0 frozen model") 37 | parser.add_argument( 38 | "--metainfo", 39 | type=str, 40 | help="path to file with meta-information") 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def extract_features(in_audios, 47 | out_files, 48 | deepspeech_pb_path, 49 | metainfo_file_path=None): 50 | """ 51 | Real extract audio from video file. 52 | Parameters 53 | ---------- 54 | in_audios : list of str 55 | Paths to input audio files. 56 | out_files : list of str 57 | Paths to output files with DeepSpeech features. 58 | deepspeech_pb_path : str 59 | Path to DeepSpeech 0.1.0 frozen model. 60 | metainfo_file_path : str, default None 61 | Path to file with meta-information. 62 | """ 63 | #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" 64 | if metainfo_file_path is None: 65 | num_frames_info = [None] * len(in_audios) 66 | else: 67 | train_df = pd.read_csv( 68 | metainfo_file_path, 69 | sep="\t", 70 | index_col=False, 71 | dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) 72 | num_frames_info = train_df["Count"].values 73 | assert (len(num_frames_info) == len(in_audios)) 74 | 75 | for i, in_audio in enumerate(in_audios): 76 | if not out_files[i]: 77 | file_stem, _ = os.path.splitext(in_audio) 78 | out_files[i] = file_stem + ".npy" 79 | #print(out_files[i]) 80 | conv_audios_to_deepspeech( 81 | audios=in_audios, 82 | out_files=out_files, 83 | num_frames_info=num_frames_info, 84 | deepspeech_pb_path=deepspeech_pb_path) 85 | 86 | 87 | def main(): 88 | """ 89 | Main body of script. 
90 | """ 91 | args = parse_args() 92 | in_audio = os.path.expanduser(args.input) 93 | if not os.path.exists(in_audio): 94 | raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) 95 | deepspeech_pb_path = args.deepspeech 96 | #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" 97 | if deepspeech_pb_path is None: 98 | deepspeech_pb_path = "" 99 | if deepspeech_pb_path: 100 | deepspeech_pb_path = os.path.expanduser(args.deepspeech) 101 | if not os.path.exists(deepspeech_pb_path): 102 | deepspeech_pb_path = get_deepspeech_model_file() 103 | if os.path.isfile(in_audio): 104 | extract_features( 105 | in_audios=[in_audio], 106 | out_files=[args.output], 107 | deepspeech_pb_path=deepspeech_pb_path, 108 | metainfo_file_path=args.metainfo) 109 | else: 110 | audio_file_paths = [] 111 | for file_name in os.listdir(in_audio): 112 | if not os.path.isfile(os.path.join(in_audio, file_name)): 113 | continue 114 | _, file_ext = os.path.splitext(file_name) 115 | if file_ext.lower() == ".wav": 116 | audio_file_path = os.path.join(in_audio, file_name) 117 | audio_file_paths.append(audio_file_path) 118 | audio_file_paths = sorted(audio_file_paths) 119 | out_file_paths = [""] * len(audio_file_paths) 120 | extract_features( 121 | in_audios=audio_file_paths, 122 | out_files=out_file_paths, 123 | deepspeech_pb_path=deepspeech_pb_path, 124 | metainfo_file_path=args.metainfo) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | 130 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/extract_wav.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for extracting audio (16-bit, mono, 22000 Hz) from video file. 3 | """ 4 | 5 | import os 6 | import argparse 7 | import subprocess 8 | 9 | 10 | def parse_args(): 11 | """ 12 | Create python script parameters. 13 | 14 | Returns 15 | ------- 16 | ArgumentParser 17 | Resulted args. 18 | """ 19 | parser = argparse.ArgumentParser( 20 | description="Extract audio from video file", 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument( 23 | "--in-video", 24 | type=str, 25 | required=True, 26 | help="path to input video file or directory") 27 | parser.add_argument( 28 | "--out-audio", 29 | type=str, 30 | help="path to output audio file") 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def extract_audio(in_video, 37 | out_audio): 38 | """ 39 | Real extract audio from video file. 40 | 41 | Parameters 42 | ---------- 43 | in_video : str 44 | Path to input video file. 45 | out_audio : str 46 | Path to output audio file. 47 | """ 48 | if not out_audio: 49 | file_stem, _ = os.path.splitext(in_video) 50 | out_audio = file_stem + ".wav" 51 | # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" 52 | # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" 53 | # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" 54 | command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" 55 | subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) 56 | 57 | 58 | def main(): 59 | """ 60 | Main body of script. 
61 | """ 62 | args = parse_args() 63 | in_video = os.path.expanduser(args.in_video) 64 | if not os.path.exists(in_video): 65 | raise Exception("Input file/directory doesn't exist: {}".format(in_video)) 66 | if os.path.isfile(in_video): 67 | extract_audio( 68 | in_video=in_video, 69 | out_audio=args.out_audio) 70 | else: 71 | video_file_paths = [] 72 | for file_name in os.listdir(in_video): 73 | if not os.path.isfile(os.path.join(in_video, file_name)): 74 | continue 75 | _, file_ext = os.path.splitext(file_name) 76 | if file_ext.lower() in (".mp4", ".mkv", ".avi"): 77 | video_file_path = os.path.join(in_video, file_name) 78 | video_file_paths.append(video_file_path) 79 | video_file_paths = sorted(video_file_paths) 80 | for video_file_path in video_file_paths: 81 | extract_audio( 82 | in_video=video_file_path, 83 | out_audio="") 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/fea_win.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | net_output = np.load('french.ds.npy').reshape(-1, 29) 4 | win_size = 16 5 | zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) 6 | net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) 7 | windows = [] 8 | for window_index in range(0, net_output.shape[0] - win_size, 2): 9 | windows.append(net_output[window_index:window_index + win_size]) 10 | print(np.array(windows).shape) 11 | np.save('aud_french.npy', np.array(windows)) 12 | -------------------------------------------------------------------------------- /data_utils/face_parsing/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import sys 8 | import logging 9 | 10 | import torch.distributed as dist 11 | 12 | 13 | def setup_logger(logpth): 14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 15 | logfile = osp.join(logpth, logfile) 16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 17 | log_level = logging.INFO 18 | if dist.is_initialized() and not dist.get_rank()==0: 19 | log_level = logging.ERROR 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 21 | logging.root.addHandler(logging.StreamHandler()) 22 | 23 | 24 | -------------------------------------------------------------------------------- /data_utils/face_parsing/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is 
None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | 
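    # parameter grouping for the optimizer: conv/linear weights (weight decay)
    # vs. biases and BatchNorm parameters (no weight decay)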
def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 
234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | 255 | # return feat_out, feat_out16, feat_out32 256 | return feat_out 257 | 258 | def init_weight(self): 259 | for ly in self.children(): 260 | if isinstance(ly, nn.Conv2d): 261 | nn.init.kaiming_normal_(ly.weight, a=1) 262 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 263 | 264 | def get_params(self): 265 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 266 | for name, child in self.named_children(): 267 | child_wd_params, child_nowd_params = child.get_params() 268 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 269 | lr_mul_wd_params += child_wd_params 270 | lr_mul_nowd_params += child_nowd_params 271 | else: 272 | wd_params += child_wd_params 273 | nowd_params += child_nowd_params 274 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 275 | 276 | 277 | if __name__ == "__main__": 278 | net = BiSeNet(19) 279 | net.cuda() 280 | net.eval() 281 | in_ten = torch.randn(16, 3, 640, 480).cuda() 282 | out, out16, out32 = net(in_ten) 283 | print(out.shape) 284 | 285 | net.get_params() 286 | -------------------------------------------------------------------------------- /data_utils/face_parsing/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | 
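        # shortcut branch: identity, or a 1x1-conv downsample when the channel
        # count or stride changes; added to the residual branch computed above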
shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /data_utils/face_parsing/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import numpy as np 4 | from model import BiSeNet 5 | 6 | import torch 7 | 8 | import os 9 | import os.path as osp 10 | 11 | from PIL import Image 12 | import torchvision.transforms as transforms 13 | import cv2 14 | from pathlib import Path 15 | import configargparse 16 | import tqdm 17 | 18 | # import ttach as tta 19 | 20 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', 21 | img_size=(512, 512)): 22 | im = np.array(im) 23 | vis_im = im.copy().astype(np.uint8) 24 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 25 | vis_parsing_anno = cv2.resize( 26 | vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 27 | vis_parsing_anno_color = np.zeros( 28 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 29 | 30 | num_of_class = np.max(vis_parsing_anno) 31 | # print(num_of_class) 32 | for pi in range(1, 14): 33 | index = np.where(vis_parsing_anno == pi) 34 | vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) 35 | 36 | for pi in range(14, 
16): 37 | index = np.where(vis_parsing_anno == pi) 38 | vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) 39 | for pi in range(16, 17): 40 | index = np.where(vis_parsing_anno == pi) 41 | vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) 42 | for pi in range(17, num_of_class+1): 43 | index = np.where(vis_parsing_anno == pi) 44 | vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) 45 | 46 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 47 | index = np.where(vis_parsing_anno == num_of_class-1) 48 | vis_im = cv2.resize(vis_parsing_anno_color, img_size, 49 | interpolation=cv2.INTER_NEAREST) 50 | if save_im: 51 | cv2.imwrite(save_path, vis_im) 52 | 53 | 54 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): 55 | 56 | Path(respth).mkdir(parents=True, exist_ok=True) 57 | 58 | print(f'[INFO] loading model...') 59 | n_classes = 19 60 | net = BiSeNet(n_classes=n_classes) 61 | net.cuda() 62 | net.load_state_dict(torch.load(cp)) 63 | net.eval() 64 | 65 | to_tensor = transforms.Compose([ 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 68 | ]) 69 | 70 | image_paths = os.listdir(dspth) 71 | 72 | with torch.no_grad(): 73 | for image_path in tqdm.tqdm(image_paths): 74 | if image_path.endswith('.jpg') or image_path.endswith('.png'): 75 | img = Image.open(osp.join(dspth, image_path)) 76 | ori_size = img.size 77 | image = img.resize((512, 512), Image.BILINEAR) 78 | image = image.convert("RGB") 79 | img = to_tensor(image) 80 | 81 | # test-time augmentation. 82 | inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512] 83 | outputs = net(inputs.cuda()) 84 | parsing = outputs.mean(0).cpu().numpy().argmax(0) 85 | 86 | image_path = int(image_path[:-4]) 87 | image_path = str(image_path) + '.png' 88 | 89 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = configargparse.ArgumentParser() 94 | parser.add_argument('--respath', type=str, default='./result/', help='result path for label') 95 | parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images') 96 | parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth') 97 | args = parser.parse_args() 98 | evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) 99 | -------------------------------------------------------------------------------- /data_utils/face_tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashawkey/RAD-NeRF/0de5ed259592592294677ad6cf7605f478a0de57/data_utils/face_tracking/__init__.py -------------------------------------------------------------------------------- /data_utils/face_tracking/convert_BFM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | original_BFM = loadmat("3DMM/01_MorphableModel.mat") 5 | sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] 6 | 7 | shapePC = original_BFM["shapePC"] 8 | shapeEV = original_BFM["shapeEV"] 9 | shapeMU = original_BFM["shapeMU"] 10 | texPC = original_BFM["texPC"] 11 | texEV = original_BFM["texEV"] 12 | texMU = original_BFM["texMU"] 13 | 14 | b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 15 | mu_shape = shapeMU.reshape(-1, 
3) 16 | 17 | b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 18 | mu_tex = texMU.reshape(-1, 3) 19 | 20 | b_shape = b_shape[:, sub_inds, :].reshape(199, -1) 21 | mu_shape = mu_shape[sub_inds, :].reshape(-1) 22 | b_tex = b_tex[:, sub_inds, :].reshape(199, -1) 23 | mu_tex = mu_tex[sub_inds, :].reshape(-1) 24 | 25 | exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() 26 | np.save( 27 | "3DMM/3DMM_info.npy", 28 | { 29 | "mu_shape": mu_shape, 30 | "b_shape": b_shape, 31 | "sig_shape": shapeEV.reshape(-1), 32 | "mu_exp": exp_info["mu_exp"], 33 | "b_exp": exp_info["base_exp"], 34 | "sig_exp": exp_info["sig_exp"], 35 | "mu_tex": mu_tex, 36 | "b_tex": b_tex, 37 | "sig_tex": texEV.reshape(-1), 38 | }, 39 | ) 40 | -------------------------------------------------------------------------------- /data_utils/face_tracking/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def load_dir(path, start, end): 7 | lmss = [] 8 | imgs_paths = [] 9 | for i in range(start, end): 10 | if os.path.isfile(os.path.join(path, str(i) + ".lms")): 11 | lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) 12 | lmss.append(lms) 13 | imgs_paths.append(os.path.join(path, str(i) + ".jpg")) 14 | lmss = np.stack(lmss) 15 | lmss = torch.as_tensor(lmss).cuda() 16 | return lmss, imgs_paths 17 | -------------------------------------------------------------------------------- /data_utils/face_tracking/face_tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | from pathlib import Path 6 | import torch 7 | import numpy as np 8 | from data_loader import load_dir 9 | from facemodel import Face_3DMM 10 | from util import * 11 | from render_3dmm import Render_3DMM 12 | 13 | 14 | # torch.autograd.set_detect_anomaly(True) 15 | 16 | dir_path = os.path.dirname(os.path.realpath(__file__)) 17 | 18 | 19 | def set_requires_grad(tensor_list): 20 | for tensor in tensor_list: 21 | tensor.requires_grad = True 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--path", type=str, default="obama/ori_imgs", help="idname of target person" 27 | ) 28 | parser.add_argument("--img_h", type=int, default=512, help="image height") 29 | parser.add_argument("--img_w", type=int, default=512, help="image width") 30 | parser.add_argument("--frame_num", type=int, default=11000, help="image number") 31 | args = parser.parse_args() 32 | 33 | start_id = 0 34 | end_id = args.frame_num 35 | 36 | lms, img_paths = load_dir(args.path, start_id, end_id) 37 | num_frames = lms.shape[0] 38 | h, w = args.img_h, args.img_w 39 | cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() 40 | id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 41 | model_3dmm = Face_3DMM( 42 | os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num 43 | ) 44 | 45 | # only use one image per 40 to do fit the focal length 46 | sel_ids = np.arange(0, num_frames, 40) 47 | sel_num = sel_ids.shape[0] 48 | arg_focal = 1600 49 | arg_landis = 1e5 50 | 51 | print(f'[INFO] fitting focal length...') 52 | 53 | # fit the focal length 54 | for focal in range(600, 1500, 100): 55 | id_para = lms.new_zeros((1, id_dim), requires_grad=True) 56 | exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) 57 | euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) 58 | trans = 
lms.new_zeros((sel_num, 3), requires_grad=True) 59 | trans.data[:, 2] -= 7 60 | focal_length = lms.new_zeros(1, requires_grad=False) 61 | focal_length.data += focal 62 | set_requires_grad([id_para, exp_para, euler_angle, trans]) 63 | 64 | optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) 65 | optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1) 66 | 67 | for iter in range(2000): 68 | id_para_batch = id_para.expand(sel_num, -1) 69 | geometry = model_3dmm.get_3dlandmarks( 70 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 71 | ) 72 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 73 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) 74 | loss = loss_lan 75 | optimizer_frame.zero_grad() 76 | loss.backward() 77 | optimizer_frame.step() 78 | # if iter % 100 == 0: 79 | # print(focal, 'pose', iter, loss.item()) 80 | 81 | for iter in range(2500): 82 | id_para_batch = id_para.expand(sel_num, -1) 83 | geometry = model_3dmm.get_3dlandmarks( 84 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 85 | ) 86 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 87 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) 88 | loss_regid = torch.mean(id_para * id_para) 89 | loss_regexp = torch.mean(exp_para * exp_para) 90 | loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 91 | optimizer_idexp.zero_grad() 92 | optimizer_frame.zero_grad() 93 | loss.backward() 94 | optimizer_idexp.step() 95 | optimizer_frame.step() 96 | # if iter % 100 == 0: 97 | # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) 98 | 99 | if iter % 1500 == 0 and iter >= 1500: 100 | for param_group in optimizer_idexp.param_groups: 101 | param_group["lr"] *= 0.2 102 | for param_group in optimizer_frame.param_groups: 103 | param_group["lr"] *= 0.2 104 | 105 | print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item()) 106 | 107 | if loss_lan.item() < arg_landis: 108 | arg_landis = loss_lan.item() 109 | arg_focal = focal 110 | 111 | print("[INFO] find best focal:", arg_focal) 112 | 113 | print(f'[INFO] coarse fitting...') 114 | 115 | # for all frames, do a coarse fitting ??? 116 | id_para = lms.new_zeros((1, id_dim), requires_grad=True) 117 | exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) 118 | tex_para = lms.new_zeros( 119 | (1, tex_dim), requires_grad=True 120 | ) # not optimized in this block ??? 121 | euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) 122 | trans = lms.new_zeros((num_frames, 3), requires_grad=True) 123 | light_para = lms.new_zeros((num_frames, 27), requires_grad=True) 124 | trans.data[:, 2] -= 7 # ??? 
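# the -7 z-offset (also applied in the focal-length search above) presumably just
# places the head a plausible distance in front of the camera before optimizing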
125 | focal_length = lms.new_zeros(1, requires_grad=True) 126 | focal_length.data += arg_focal 127 | 128 | set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para]) 129 | 130 | optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) 131 | optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) 132 | 133 | for iter in range(1500): 134 | id_para_batch = id_para.expand(num_frames, -1) 135 | geometry = model_3dmm.get_3dlandmarks( 136 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 137 | ) 138 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 139 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) 140 | loss = loss_lan 141 | optimizer_frame.zero_grad() 142 | loss.backward() 143 | optimizer_frame.step() 144 | if iter == 1000: 145 | for param_group in optimizer_frame.param_groups: 146 | param_group["lr"] = 0.1 147 | # if iter % 100 == 0: 148 | # print('pose', iter, loss.item()) 149 | 150 | for param_group in optimizer_frame.param_groups: 151 | param_group["lr"] = 0.1 152 | 153 | for iter in range(2000): 154 | id_para_batch = id_para.expand(num_frames, -1) 155 | geometry = model_3dmm.get_3dlandmarks( 156 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 157 | ) 158 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 159 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) 160 | loss_regid = torch.mean(id_para * id_para) 161 | loss_regexp = torch.mean(exp_para * exp_para) 162 | loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 163 | optimizer_idexp.zero_grad() 164 | optimizer_frame.zero_grad() 165 | loss.backward() 166 | optimizer_idexp.step() 167 | optimizer_frame.step() 168 | # if iter % 100 == 0: 169 | # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) 170 | if iter % 1000 == 0 and iter >= 1000: 171 | for param_group in optimizer_idexp.param_groups: 172 | param_group["lr"] *= 0.2 173 | for param_group in optimizer_frame.param_groups: 174 | param_group["lr"] *= 0.2 175 | 176 | print(loss_lan.item(), torch.mean(trans[:, 2]).item()) 177 | 178 | print(f'[INFO] fitting light...') 179 | 180 | batch_size = 64 181 | 182 | device_default = torch.device("cuda:0") 183 | device_render = torch.device("cuda:0") 184 | renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) 185 | 186 | sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] 187 | imgs = [] 188 | for sel_id in sel_ids: 189 | imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) 190 | imgs = np.stack(imgs) 191 | sel_imgs = torch.as_tensor(imgs).cuda() 192 | sel_lms = lms[sel_ids] 193 | sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) 194 | set_requires_grad([sel_light]) 195 | 196 | optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1) 197 | optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01) 198 | 199 | for iter in range(71): 200 | sel_exp_para, sel_euler, sel_trans = ( 201 | exp_para[sel_ids], 202 | euler_angle[sel_ids], 203 | trans[sel_ids], 204 | ) 205 | sel_id_para = id_para.expand(batch_size, -1) 206 | geometry = model_3dmm.get_3dlandmarks( 207 | sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy 208 | ) 209 | proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) 210 | 211 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) 212 | loss_regid = torch.mean(id_para * id_para) 213 | loss_regexp = 
torch.mean(sel_exp_para * sel_exp_para) 214 | 215 | sel_tex_para = tex_para.expand(batch_size, -1) 216 | sel_texture = model_3dmm.forward_tex(sel_tex_para) 217 | geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 218 | rott_geo = forward_rott(geometry, sel_euler, sel_trans) 219 | render_imgs = renderer( 220 | rott_geo.to(device_render), 221 | sel_texture.to(device_render), 222 | sel_light.to(device_render), 223 | ) 224 | render_imgs = render_imgs.to(device_default) 225 | 226 | mask = (render_imgs[:, :, :, 3]).detach() > 0.0 227 | render_proj = sel_imgs.clone() 228 | render_proj[mask] = render_imgs[mask][..., :3].byte() 229 | loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) 230 | 231 | if iter > 50: 232 | loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8 233 | else: 234 | loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0 235 | 236 | optimizer_tl.zero_grad() 237 | optimizer_id_frame.zero_grad() 238 | loss.backward() 239 | 240 | optimizer_tl.step() 241 | optimizer_id_frame.step() 242 | 243 | if iter % 50 == 0 and iter > 0: 244 | for param_group in optimizer_id_frame.param_groups: 245 | param_group["lr"] *= 0.2 246 | for param_group in optimizer_tl.param_groups: 247 | param_group["lr"] *= 0.2 248 | # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) 249 | 250 | 251 | light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) 252 | light_para.data = light_mean 253 | 254 | exp_para = exp_para.detach() 255 | euler_angle = euler_angle.detach() 256 | trans = trans.detach() 257 | light_para = light_para.detach() 258 | 259 | print(f'[INFO] fine frame-wise fitting...') 260 | 261 | for i in range(int((num_frames - 1) / batch_size + 1)): 262 | 263 | if (i + 1) * batch_size > num_frames: 264 | start_n = num_frames - batch_size 265 | sel_ids = np.arange(num_frames - batch_size, num_frames) 266 | else: 267 | start_n = i * batch_size 268 | sel_ids = np.arange(i * batch_size, i * batch_size + batch_size) 269 | 270 | imgs = [] 271 | for sel_id in sel_ids: 272 | imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) 273 | imgs = np.stack(imgs) 274 | sel_imgs = torch.as_tensor(imgs).cuda() 275 | sel_lms = lms[sel_ids] 276 | 277 | sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True) 278 | sel_exp_para.data = exp_para[sel_ids].clone() 279 | sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True) 280 | sel_euler.data = euler_angle[sel_ids].clone() 281 | sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) 282 | sel_trans.data = trans[sel_ids].clone() 283 | sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) 284 | sel_light.data = light_para[sel_ids].clone() 285 | 286 | set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) 287 | 288 | optimizer_cur_batch = torch.optim.Adam( 289 | [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005 290 | ) 291 | 292 | sel_id_para = id_para.expand(batch_size, -1).detach() 293 | sel_tex_para = tex_para.expand(batch_size, -1).detach() 294 | 295 | pre_num = 5 296 | 297 | if i > 0: 298 | pre_ids = np.arange(start_n - pre_num, start_n) 299 | 300 | for iter in range(50): 301 | 302 | geometry = model_3dmm.get_3dlandmarks( 303 | sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy 304 | ) 305 | proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) 306 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) 307 | loss_regexp = 
torch.mean(sel_exp_para * sel_exp_para) 308 | 309 | sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 310 | sel_texture = model_3dmm.forward_tex(sel_tex_para) 311 | geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 312 | rott_geo = forward_rott(geometry, sel_euler, sel_trans) 313 | render_imgs = renderer( 314 | rott_geo.to(device_render), 315 | sel_texture.to(device_render), 316 | sel_light.to(device_render), 317 | ) 318 | render_imgs = render_imgs.to(device_default) 319 | 320 | mask = (render_imgs[:, :, :, 3]).detach() > 0.0 321 | 322 | loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) 323 | 324 | if i > 0: 325 | geometry_lap = model_3dmm.forward_geo_sub( 326 | id_para.expand(batch_size + pre_num, -1).detach(), 327 | torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), 328 | model_3dmm.rigid_ids, 329 | ) 330 | rott_geo_lap = forward_rott( 331 | geometry_lap, 332 | torch.cat((euler_angle[pre_ids].detach(), sel_euler)), 333 | torch.cat((trans[pre_ids].detach(), sel_trans)), 334 | ) 335 | loss_lap = cal_lap_loss( 336 | [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] 337 | ) 338 | else: 339 | geometry_lap = model_3dmm.forward_geo_sub( 340 | id_para.expand(batch_size, -1).detach(), 341 | sel_exp_para, 342 | model_3dmm.rigid_ids, 343 | ) 344 | rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) 345 | loss_lap = cal_lap_loss( 346 | [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] 347 | ) 348 | 349 | 350 | if iter > 30: 351 | loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0 352 | else: 353 | loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0 354 | 355 | optimizer_cur_batch.zero_grad() 356 | loss.backward() 357 | optimizer_cur_batch.step() 358 | 359 | # if iter % 10 == 0: 360 | # print( 361 | # i, 362 | # iter, 363 | # loss_col.item(), 364 | # loss_lan.item(), 365 | # loss_lap.item(), 366 | # loss_regexp.item(), 367 | # ) 368 | 369 | print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done") 370 | 371 | render_proj = sel_imgs.clone() 372 | render_proj[mask] = render_imgs[mask][..., :3].byte() 373 | 374 | exp_para[sel_ids] = sel_exp_para.clone() 375 | euler_angle[sel_ids] = sel_euler.clone() 376 | trans[sel_ids] = sel_trans.clone() 377 | light_para[sel_ids] = sel_light.clone() 378 | 379 | torch.save( 380 | { 381 | "id": id_para.detach().cpu(), 382 | "exp": exp_para.detach().cpu(), 383 | "euler": euler_angle.detach().cpu(), 384 | "trans": trans.detach().cpu(), 385 | "focal": focal_length.detach().cpu(), 386 | }, 387 | os.path.join(os.path.dirname(args.path), "track_params.pt"), 388 | ) 389 | 390 | print("params saved") 391 | -------------------------------------------------------------------------------- /data_utils/face_tracking/facemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from util import * 6 | 7 | 8 | class Face_3DMM(nn.Module): 9 | def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): 10 | super(Face_3DMM, self).__init__() 11 | # id_dim = 100 12 | # exp_dim = 79 13 | # tex_dim = 100 14 | self.point_num = point_num 15 | DMM_info = np.load( 16 | os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True 17 | ).item() 18 | base_id = DMM_info["b_shape"][:id_dim, :] 19 | mu_id = DMM_info["mu_shape"] 20 | base_exp = DMM_info["b_exp"][:exp_dim, :] 21 | mu_exp = 
DMM_info["mu_exp"] 22 | mu = mu_id + mu_exp 23 | mu = mu.reshape(-1, 3) 24 | for i in range(3): 25 | mu[:, i] -= np.mean(mu[:, i]) 26 | mu = mu.reshape(-1) 27 | self.base_id = torch.as_tensor(base_id).cuda() / 100000.0 28 | self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0 29 | self.mu = torch.as_tensor(mu).cuda() / 100000.0 30 | base_tex = DMM_info["b_tex"][:tex_dim, :] 31 | mu_tex = DMM_info["mu_tex"] 32 | self.base_tex = torch.as_tensor(base_tex).cuda() 33 | self.mu_tex = torch.as_tensor(mu_tex).cuda() 34 | sig_id = DMM_info["sig_shape"][:id_dim] 35 | sig_tex = DMM_info["sig_tex"][:tex_dim] 36 | sig_exp = DMM_info["sig_exp"][:exp_dim] 37 | self.sig_id = torch.as_tensor(sig_id).cuda() 38 | self.sig_tex = torch.as_tensor(sig_tex).cuda() 39 | self.sig_exp = torch.as_tensor(sig_exp).cuda() 40 | 41 | keys_info = np.load( 42 | os.path.join(modelpath, "keys_info.npy"), allow_pickle=True 43 | ).item() 44 | self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda() 45 | self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda() 46 | self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda() 47 | self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda() 48 | 49 | def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): 50 | id_para = id_para * self.sig_id 51 | exp_para = exp_para * self.sig_exp 52 | batch_size = id_para.shape[0] 53 | num_per_contour = self.left_contours.shape[1] 54 | left_contours_flat = self.left_contours.reshape(-1) 55 | right_contours_flat = self.right_contours.reshape(-1) 56 | sel_index = torch.cat( 57 | ( 58 | 3 * left_contours_flat.unsqueeze(1), 59 | 3 * left_contours_flat.unsqueeze(1) + 1, 60 | 3 * left_contours_flat.unsqueeze(1) + 2, 61 | ), 62 | dim=1, 63 | ).reshape(-1) 64 | left_geometry = ( 65 | torch.mm(id_para, self.base_id[:, sel_index]) 66 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 67 | + self.mu[sel_index] 68 | ) 69 | left_geometry = left_geometry.view(batch_size, -1, 3) 70 | proj_x = forward_transform( 71 | left_geometry, euler_angle, trans, focal_length, cxy 72 | )[:, :, 0] 73 | proj_x = proj_x.reshape(batch_size, 8, num_per_contour) 74 | arg_min = proj_x.argmin(dim=2) 75 | left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3) 76 | left_3dlands = left_geometry[ 77 | torch.arange(batch_size * 8), arg_min.view(-1), : 78 | ].view(batch_size, 8, 3) 79 | 80 | sel_index = torch.cat( 81 | ( 82 | 3 * right_contours_flat.unsqueeze(1), 83 | 3 * right_contours_flat.unsqueeze(1) + 1, 84 | 3 * right_contours_flat.unsqueeze(1) + 2, 85 | ), 86 | dim=1, 87 | ).reshape(-1) 88 | right_geometry = ( 89 | torch.mm(id_para, self.base_id[:, sel_index]) 90 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 91 | + self.mu[sel_index] 92 | ) 93 | right_geometry = right_geometry.view(batch_size, -1, 3) 94 | proj_x = forward_transform( 95 | right_geometry, euler_angle, trans, focal_length, cxy 96 | )[:, :, 0] 97 | proj_x = proj_x.reshape(batch_size, 8, num_per_contour) 98 | arg_max = proj_x.argmax(dim=2) 99 | right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3) 100 | right_3dlands = right_geometry[ 101 | torch.arange(batch_size * 8), arg_max.view(-1), : 102 | ].view(batch_size, 8, 3) 103 | 104 | sel_index = torch.cat( 105 | ( 106 | 3 * self.keyinds.unsqueeze(1), 107 | 3 * self.keyinds.unsqueeze(1) + 1, 108 | 3 * self.keyinds.unsqueeze(1) + 2, 109 | ), 110 | dim=1, 111 | ).reshape(-1) 112 | geometry = ( 113 | torch.mm(id_para, self.base_id[:, sel_index]) 114 | + 
torch.mm(exp_para, self.base_exp[:, sel_index]) 115 | + self.mu[sel_index] 116 | ) 117 | lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) 118 | lands_3d[:, :8, :] = left_3dlands 119 | lands_3d[:, 9:17, :] = right_3dlands 120 | return lands_3d 121 | 122 | def forward_geo_sub(self, id_para, exp_para, sub_index): 123 | id_para = id_para * self.sig_id 124 | exp_para = exp_para * self.sig_exp 125 | sel_index = torch.cat( 126 | ( 127 | 3 * sub_index.unsqueeze(1), 128 | 3 * sub_index.unsqueeze(1) + 1, 129 | 3 * sub_index.unsqueeze(1) + 2, 130 | ), 131 | dim=1, 132 | ).reshape(-1) 133 | geometry = ( 134 | torch.mm(id_para, self.base_id[:, sel_index]) 135 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 136 | + self.mu[sel_index] 137 | ) 138 | return geometry.reshape(-1, sub_index.shape[0], 3) 139 | 140 | def forward_geo(self, id_para, exp_para): 141 | id_para = id_para * self.sig_id 142 | exp_para = exp_para * self.sig_exp 143 | geometry = ( 144 | torch.mm(id_para, self.base_id) 145 | + torch.mm(exp_para, self.base_exp) 146 | + self.mu 147 | ) 148 | return geometry.reshape(-1, self.point_num, 3) 149 | 150 | def forward_tex(self, tex_para): 151 | tex_para = tex_para * self.sig_tex 152 | texture = torch.mm(tex_para, self.base_tex) + self.mu_tex 153 | return texture.reshape(-1, self.point_num, 3) 154 | -------------------------------------------------------------------------------- /data_utils/face_tracking/geo_transform.py: -------------------------------------------------------------------------------- 1 | """This module contains functions for geometry transform and camera projection""" 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def euler2rot(euler_angle): 8 | batch_size = euler_angle.shape[0] 9 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 10 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 11 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 12 | one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) 13 | zero = torch.zeros( 14 | (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device 15 | ) 16 | rot_x = torch.cat( 17 | ( 18 | torch.cat((one, zero, zero), 1), 19 | torch.cat((zero, theta.cos(), theta.sin()), 1), 20 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 21 | ), 22 | 2, 23 | ) 24 | rot_y = torch.cat( 25 | ( 26 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 27 | torch.cat((zero, one, zero), 1), 28 | torch.cat((phi.sin(), zero, phi.cos()), 1), 29 | ), 30 | 2, 31 | ) 32 | rot_z = torch.cat( 33 | ( 34 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 35 | torch.cat((psi.sin(), psi.cos(), zero), 1), 36 | torch.cat((zero, zero, one), 1), 37 | ), 38 | 2, 39 | ) 40 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 41 | 42 | 43 | def rot_trans_geo(geometry, rot, trans): 44 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) 45 | return rott_geo.permute(0, 2, 1) 46 | 47 | 48 | def euler_trans_geo(geometry, euler, trans): 49 | rot = euler2rot(euler) 50 | return rot_trans_geo(geometry, rot, trans) 51 | 52 | 53 | def proj_geo(rott_geo, camera_para): 54 | fx = camera_para[:, 0] 55 | fy = camera_para[:, 0] 56 | cx = camera_para[:, 1] 57 | cy = camera_para[:, 2] 58 | 59 | X = rott_geo[:, :, 0] 60 | Y = rott_geo[:, :, 1] 61 | Z = rott_geo[:, :, 2] 62 | 63 | fxX = fx[:, None] * X 64 | fyY = fy[:, None] * Y 65 | 66 | proj_x = -fxX / Z + cx[:, None] 67 | proj_y = fyY / Z + cy[:, None] 68 | 69 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 70 | 
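A minimal usage sketch of the projection helpers above (not part of the repository): shapes and camera values are hypothetical, and the bare `import geo_transform` assumes the snippet is run from inside `data_utils/face_tracking`, the same way `render_land.py` imports this module.

```python
import torch
import geo_transform  # the module defined above

B, N = 2, 68                                               # two frames, 68 points each (hypothetical)
geometry = torch.rand(B, N, 3) - 0.5                       # canonical 3D points around the origin
euler = torch.zeros(B, 3)                                  # zero angles -> euler2rot returns identity
trans = torch.tensor([[0.0, 0.0, -10.0]]).repeat(B, 1)     # negative z places points in front of the camera
cam = torch.tensor([[1200.0, 250.0, 250.0]]).repeat(B, 1)  # [focal, cx, cy] for a 500x500 image

rott = geo_transform.euler_trans_geo(geometry, euler, trans)  # [B, N, 3] camera-space points
proj = geo_transform.proj_geo(rott, cam)                      # [B, N, 3]: (x_pix, y_pix, depth)
print(proj.shape)  # torch.Size([2, 68, 3]); note proj_geo negates the fx*X/Z term
```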
-------------------------------------------------------------------------------- /data_utils/face_tracking/render_3dmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from pytorch3d.structures import Meshes 6 | from pytorch3d.renderer import ( 7 | look_at_view_transform, 8 | PerspectiveCameras, 9 | FoVPerspectiveCameras, 10 | PointLights, 11 | DirectionalLights, 12 | Materials, 13 | RasterizationSettings, 14 | MeshRenderer, 15 | MeshRasterizer, 16 | SoftPhongShader, 17 | TexturesUV, 18 | TexturesVertex, 19 | blending, 20 | ) 21 | 22 | from pytorch3d.ops import interpolate_face_attributes 23 | 24 | from pytorch3d.renderer.blending import ( 25 | BlendParams, 26 | hard_rgb_blend, 27 | sigmoid_alpha_blend, 28 | softmax_rgb_blend, 29 | ) 30 | 31 | 32 | class SoftSimpleShader(nn.Module): 33 | """ 34 | Per pixel lighting - the lighting model is applied using the interpolated 35 | coordinates and normals for each pixel. The blending function returns the 36 | soft aggregated color using all the faces per pixel. 37 | 38 | To use the default values, simply initialize the shader with the desired 39 | device e.g. 40 | 41 | """ 42 | 43 | def __init__( 44 | self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None 45 | ): 46 | super().__init__() 47 | self.lights = lights if lights is not None else PointLights(device=device) 48 | self.materials = ( 49 | materials if materials is not None else Materials(device=device) 50 | ) 51 | self.cameras = cameras 52 | self.blend_params = blend_params if blend_params is not None else BlendParams() 53 | 54 | def to(self, device): 55 | # Manually move to device modules which are not subclasses of nn.Module 56 | self.cameras = self.cameras.to(device) 57 | self.materials = self.materials.to(device) 58 | self.lights = self.lights.to(device) 59 | return self 60 | 61 | def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: 62 | 63 | texels = meshes.sample_textures(fragments) 64 | blend_params = kwargs.get("blend_params", self.blend_params) 65 | 66 | cameras = kwargs.get("cameras", self.cameras) 67 | if cameras is None: 68 | msg = "Cameras must be specified either at initialization \ 69 | or in the forward pass of SoftPhongShader" 70 | raise ValueError(msg) 71 | znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) 72 | zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) 73 | images = softmax_rgb_blend( 74 | texels, fragments, blend_params, znear=znear, zfar=zfar 75 | ) 76 | return images 77 | 78 | 79 | class Render_3DMM(nn.Module): 80 | def __init__( 81 | self, 82 | focal=1015, 83 | img_h=500, 84 | img_w=500, 85 | batch_size=1, 86 | device=torch.device("cuda:0"), 87 | ): 88 | super(Render_3DMM, self).__init__() 89 | 90 | self.focal = focal 91 | self.img_h = img_h 92 | self.img_w = img_w 93 | self.device = device 94 | self.renderer = self.get_render(batch_size) 95 | 96 | dir_path = os.path.dirname(os.path.realpath(__file__)) 97 | topo_info = np.load( 98 | os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True 99 | ).item() 100 | self.tris = torch.as_tensor(topo_info["tris"]).to(self.device) 101 | self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device) 102 | 103 | def compute_normal(self, geometry): 104 | vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) 105 | vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) 106 | vert_3 = torch.index_select(geometry, 1, 
self.tris[:, 2]) 107 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) 108 | tri_normal = nn.functional.normalize(nnorm, dim=2) 109 | v_norm = tri_normal[:, self.vert_tris, :].sum(2) 110 | vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) 111 | return vert_normal 112 | 113 | def get_render(self, batch_size=1): 114 | half_s = self.img_w * 0.5 115 | R, T = look_at_view_transform(10, 0, 0) 116 | R = R.repeat(batch_size, 1, 1) 117 | T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device) 118 | 119 | cameras = FoVPerspectiveCameras( 120 | device=self.device, 121 | R=R, 122 | T=T, 123 | znear=0.01, 124 | zfar=20, 125 | fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi, 126 | ) 127 | lights = PointLights( 128 | device=self.device, 129 | location=[[0.0, 0.0, 1e5]], 130 | ambient_color=[[1, 1, 1]], 131 | specular_color=[[0.0, 0.0, 0.0]], 132 | diffuse_color=[[0.0, 0.0, 0.0]], 133 | ) 134 | sigma = 1e-4 135 | raster_settings = RasterizationSettings( 136 | image_size=(self.img_h, self.img_w), 137 | blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0, 138 | faces_per_pixel=2, 139 | perspective_correct=False, 140 | ) 141 | blend_params = blending.BlendParams(background_color=[0, 0, 0]) 142 | renderer = MeshRenderer( 143 | rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras), 144 | shader=SoftSimpleShader( 145 | lights=lights, blend_params=blend_params, cameras=cameras 146 | ), 147 | ) 148 | return renderer.to(self.device) 149 | 150 | @staticmethod 151 | def Illumination_layer(face_texture, norm, gamma): 152 | 153 | n_b, num_vertex, _ = face_texture.size() 154 | n_v_full = n_b * num_vertex 155 | gamma = gamma.view(-1, 3, 9).clone() 156 | gamma[:, :, 0] += 0.8 157 | 158 | gamma = gamma.permute(0, 2, 1) 159 | 160 | a0 = np.pi 161 | a1 = 2 * np.pi / np.sqrt(3.0) 162 | a2 = 2 * np.pi / np.sqrt(8.0) 163 | c0 = 1 / np.sqrt(4 * np.pi) 164 | c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) 165 | c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) 166 | d0 = 0.5 / np.sqrt(3.0) 167 | 168 | Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 169 | norm = norm.view(-1, 3) 170 | nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] 171 | arrH = [] 172 | 173 | arrH.append(Y0) 174 | arrH.append(-a1 * c1 * ny) 175 | arrH.append(a1 * c1 * nz) 176 | arrH.append(-a1 * c1 * nx) 177 | arrH.append(a2 * c2 * nx * ny) 178 | arrH.append(-a2 * c2 * ny * nz) 179 | arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) 180 | arrH.append(-a2 * c2 * nx * nz) 181 | arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) 182 | 183 | H = torch.stack(arrH, 1) 184 | Y = H.view(n_b, num_vertex, 9) 185 | lighting = Y.bmm(gamma) 186 | 187 | face_color = face_texture * lighting 188 | return face_color 189 | 190 | def forward(self, rott_geometry, texture, diffuse_sh): 191 | face_normal = self.compute_normal(rott_geometry) 192 | face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) 193 | face_color = TexturesVertex(face_color) 194 | mesh = Meshes( 195 | rott_geometry, 196 | self.tris.float().repeat(rott_geometry.shape[0], 1, 1), 197 | face_color, 198 | ) 199 | rendered_img = self.renderer(mesh) 200 | rendered_img = torch.clamp(rendered_img, 0, 255) 201 | 202 | return rendered_img 203 | -------------------------------------------------------------------------------- /data_utils/face_tracking/render_land.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import render_util 4 | import geo_transform 5 | import 
numpy as np 6 | 7 | 8 | def compute_tri_normal(geometry, tris): 9 | geometry = geometry.permute(0, 2, 1) 10 | tri_1 = tris[:, 0] 11 | tri_2 = tris[:, 1] 12 | tri_3 = tris[:, 2] 13 | 14 | vert_1 = torch.index_select(geometry, 2, tri_1) 15 | vert_2 = torch.index_select(geometry, 2, tri_2) 16 | vert_3 = torch.index_select(geometry, 2, tri_3) 17 | 18 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1) 19 | normal = nn.functional.normalize(nnorm).permute(0, 2, 1) 20 | return normal 21 | 22 | 23 | class Compute_normal_base(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, normal): 26 | (normal_b,) = render_util.normal_base_forward(normal) 27 | ctx.save_for_backward(normal) 28 | return normal_b 29 | 30 | @staticmethod 31 | def backward(ctx, grad_normal_b): 32 | (normal,) = ctx.saved_tensors 33 | (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal) 34 | return grad_normal 35 | 36 | 37 | class Normal_Base(torch.nn.Module): 38 | def __init__(self): 39 | super(Normal_Base, self).__init__() 40 | 41 | def forward(self, normal): 42 | return Compute_normal_base.apply(normal) 43 | 44 | 45 | def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): 46 | point_num = geometry.shape[1] 47 | rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) 48 | proj_geo = geo_transform.proj_geo(rott_geo, cam) 49 | rot_tri_normal = compute_tri_normal(rott_geo, tris) 50 | rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) 51 | is_visible = -torch.bmm( 52 | rot_vert_normal.reshape(-1, 1, 3), 53 | nn.functional.normalize(rott_geo.reshape(-1, 3, 1)), 54 | ).reshape(-1, point_num) 55 | is_visible[is_visible < 0.01] = -1 56 | pixel_valid = torch.zeros( 57 | (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]), 58 | dtype=torch.float32, 59 | device=ori_img.device, 60 | ) 61 | return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid 62 | 63 | 64 | class Render_Face(torch.autograd.Function): 65 | @staticmethod 66 | def forward( 67 | ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 68 | ): 69 | batch_size, h, w, _ = ori_img.shape 70 | ori_img = ori_img.view(batch_size, -1, 3) 71 | ori_size = torch.cat( 72 | ( 73 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 74 | * h, 75 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 76 | * w, 77 | ), 78 | dim=1, 79 | ).view(-1) 80 | tri_index, tri_coord, render, real = render_util.render_face_forward( 81 | proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid 82 | ) 83 | ctx.save_for_backward( 84 | ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord 85 | ) 86 | return render, real 87 | 88 | @staticmethod 89 | def backward(ctx, grad_render, grad_real): 90 | ( 91 | ori_img, 92 | ori_size, 93 | proj_geo, 94 | texture, 95 | nbl, 96 | tri_inds, 97 | tri_index, 98 | tri_coord, 99 | ) = ctx.saved_tensors 100 | grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( 101 | grad_render, 102 | grad_real, 103 | ori_img, 104 | ori_size, 105 | proj_geo, 106 | texture, 107 | nbl, 108 | tri_inds, 109 | tri_index, 110 | tri_coord, 111 | ) 112 | return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None 113 | 114 | 115 | class Render_RGB(nn.Module): 116 | def __init__(self): 117 | super(Render_RGB, self).__init__() 118 | 119 | def forward( 120 | self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 121 | ): 122 | return Render_Face.apply( 123 | 
proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 124 | ) 125 | 126 | 127 | def cal_land(proj_geo, is_visible, lands_info, land_num): 128 | (land_index,) = render_util.update_contour(lands_info, is_visible, land_num) 129 | proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[ 130 | :, :2 131 | ].reshape(-1, land_num, 2) 132 | return proj_land 133 | 134 | 135 | class Render_Land(nn.Module): 136 | def __init__(self): 137 | super(Render_Land, self).__init__() 138 | lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32) 139 | self.lands_info = torch.as_tensor(lands_info).cuda() 140 | tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64) 141 | self.tris = torch.as_tensor(tris).cuda() - 1 142 | vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64) 143 | self.vert_tris = torch.as_tensor(vert_tris).cuda() 144 | self.normal_baser = Normal_Base().cuda() 145 | self.renderer = Render_RGB().cuda() 146 | 147 | def render_mesh(self, geometry, euler, trans, cam, ori_img, light): 148 | batch_size, h, w, _ = ori_img.shape 149 | ori_img = ori_img.view(batch_size, -1, 3) 150 | ori_size = torch.cat( 151 | ( 152 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 153 | * h, 154 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 155 | * w, 156 | ), 157 | dim=1, 158 | ).view(-1) 159 | rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( 160 | geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img 161 | ) 162 | tri_nb = self.normal_baser(rot_tri_normal.contiguous()) 163 | nbl = torch.bmm( 164 | tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3) 165 | ) 166 | texture = torch.ones_like(geometry) * 200 167 | (render,) = render_util.render_mesh( 168 | proj_geo, ori_img, ori_size, texture, nbl, self.tris 169 | ) 170 | return render.view(batch_size, h, w, 3).byte() 171 | 172 | def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): 173 | rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render( 174 | geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img 175 | ) 176 | tri_nb = self.normal_baser(rot_tri_normal.contiguous()) 177 | nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) 178 | render, real = self.renderer( 179 | proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid 180 | ) 181 | proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1]) 182 | col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape( 183 | ori_img.shape[0], -1 184 | ) 185 | col_dis = torch.mean(col_minus * pixel_valid) / ( 186 | torch.mean(pixel_valid) + 0.00001 187 | ) 188 | land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape( 189 | ori_img.shape[0], -1 190 | ) 191 | lan_dis = torch.mean(land_dists) 192 | return col_dis, lan_dis 193 | -------------------------------------------------------------------------------- /data_utils/face_tracking/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def compute_tri_normal(geometry, tris): 7 | tri_1 = tris[:, 0] 8 | tri_2 = tris[:, 1] 9 | tri_3 = tris[:, 2] 10 | vert_1 = torch.index_select(geometry, 1, tri_1) 11 | vert_2 = torch.index_select(geometry, 1, tri_2) 12 | vert_3 = torch.index_select(geometry, 1, tri_3) 13 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) 14 | normal = 
nn.functional.normalize(nnorm) 15 | return normal 16 | 17 | 18 | def euler2rot(euler_angle): 19 | batch_size = euler_angle.shape[0] 20 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 21 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 22 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 23 | one = torch.ones(batch_size, 1, 1).to(euler_angle.device) 24 | zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) 25 | rot_x = torch.cat( 26 | ( 27 | torch.cat((one, zero, zero), 1), 28 | torch.cat((zero, theta.cos(), theta.sin()), 1), 29 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 30 | ), 31 | 2, 32 | ) 33 | rot_y = torch.cat( 34 | ( 35 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 36 | torch.cat((zero, one, zero), 1), 37 | torch.cat((phi.sin(), zero, phi.cos()), 1), 38 | ), 39 | 2, 40 | ) 41 | rot_z = torch.cat( 42 | ( 43 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 44 | torch.cat((psi.sin(), psi.cos(), zero), 1), 45 | torch.cat((zero, zero, one), 1), 46 | ), 47 | 2, 48 | ) 49 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 50 | 51 | 52 | def rot_trans_pts(geometry, rot, trans): 53 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] 54 | return rott_geo.permute(0, 2, 1) 55 | 56 | 57 | def cal_lap_loss(tensor_list, weight_list): 58 | lap_kernel = ( 59 | torch.Tensor((-0.5, 1.0, -0.5)) 60 | .unsqueeze(0) 61 | .unsqueeze(0) 62 | .float() 63 | .to(tensor_list[0].device) 64 | ) 65 | loss_lap = 0 66 | for i in range(len(tensor_list)): 67 | in_tensor = tensor_list[i] 68 | in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) 69 | out_tensor = F.conv1d(in_tensor, lap_kernel) 70 | loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] 71 | return loss_lap 72 | 73 | 74 | def proj_pts(rott_geo, focal_length, cxy): 75 | cx, cy = cxy[0], cxy[1] 76 | X = rott_geo[:, :, 0] 77 | Y = rott_geo[:, :, 1] 78 | Z = rott_geo[:, :, 2] 79 | fxX = focal_length * X 80 | fyY = focal_length * Y 81 | proj_x = -fxX / Z + cx 82 | proj_y = fyY / Z + cy 83 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 84 | 85 | 86 | def forward_rott(geometry, euler_angle, trans): 87 | rot = euler2rot(euler_angle) 88 | rott_geo = rot_trans_pts(geometry, rot, trans) 89 | return rott_geo 90 | 91 | 92 | def forward_transform(geometry, euler_angle, trans, focal_length, cxy): 93 | rot = euler2rot(euler_angle) 94 | rott_geo = rot_trans_pts(geometry, rot, trans) 95 | proj_geo = proj_pts(rott_geo, focal_length, cxy) 96 | return proj_geo 97 | 98 | 99 | def cal_lan_loss(proj_lan, gt_lan): 100 | return torch.mean((proj_lan - gt_lan) ** 2) 101 | 102 | 103 | def cal_col_loss(pred_img, gt_img, img_mask): 104 | pred_img = pred_img.float() 105 | # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 106 | loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 107 | loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) 108 | loss = torch.mean(loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /data_utils/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import json 5 | import argparse 6 | import cv2 7 | import numpy as np 8 | 9 | def extract_audio(path, out_path, sample_rate=16000): 10 | 11 | print(f'[INFO] ===== extract audio from {path} to {out_path} =====') 12 | cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}' 13 | os.system(cmd) 14 | print(f'[INFO] ===== extracted audio =====') 
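# Example (hypothetical paths): extract_audio('data/obama.mp4', 'data/aud.wav') shells out to
#   ffmpeg -i data/obama.mp4 -f wav -ar 16000 data/aud.wav
# so ffmpeg must be available on PATH.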
15 | 16 | 17 | def extract_audio_features(path, mode='wav2vec'): 18 | 19 | print(f'[INFO] ===== extract audio labels for {path} =====') 20 | if mode == 'wav2vec': 21 | cmd = f'python nerf/asr.py --wav {path} --save_feats' 22 | else: # deepspeech 23 | cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}' 24 | os.system(cmd) 25 | print(f'[INFO] ===== extracted audio labels =====') 26 | 27 | 28 | 29 | def extract_images(path, out_path, fps=25): 30 | 31 | print(f'[INFO] ===== extract images from {path} to {out_path} =====') 32 | cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}' 33 | os.system(cmd) 34 | print(f'[INFO] ===== extracted images =====') 35 | 36 | 37 | def extract_semantics(ori_imgs_dir, parsing_dir): 38 | 39 | print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====') 40 | cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}' 41 | os.system(cmd) 42 | print(f'[INFO] ===== extracted semantics =====') 43 | 44 | 45 | def extract_landmarks(ori_imgs_dir): 46 | 47 | print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====') 48 | 49 | import face_alignment 50 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) 51 | image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) 52 | for image_path in tqdm.tqdm(image_paths): 53 | input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] 54 | input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB) 55 | preds = fa.get_landmarks(input) 56 | if len(preds) > 0: 57 | lands = preds[0].reshape(-1, 2)[:,:2] 58 | np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f') 59 | del fa 60 | print(f'[INFO] ===== extracted face landmarks =====') 61 | 62 | 63 | def extract_background(base_dir, ori_imgs_dir): 64 | 65 | print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====') 66 | 67 | from sklearn.neighbors import NearestNeighbors 68 | 69 | image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) 70 | # only use 1/20 image_paths 71 | image_paths = image_paths[::20] 72 | # read one image to get H/W 73 | tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] 74 | h, w = tmp_image.shape[:2] 75 | 76 | # nearest neighbors 77 | all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() 78 | distss = [] 79 | for image_path in tqdm.tqdm(image_paths): 80 | parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) 81 | bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) 82 | fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) 83 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) 84 | dists, _ = nbrs.kneighbors(all_xys) 85 | distss.append(dists) 86 | 87 | distss = np.stack(distss) 88 | max_dist = np.max(distss, 0) 89 | max_id = np.argmax(distss, 0) 90 | 91 | bc_pixs = max_dist > 5 92 | bc_pixs_id = np.nonzero(bc_pixs) 93 | bc_ids = max_id[bc_pixs] 94 | 95 | imgs = [] 96 | num_pixs = distss.shape[1] 97 | for image_path in image_paths: 98 | img = cv2.imread(image_path) 99 | imgs.append(img) 100 | imgs = np.stack(imgs).reshape(-1, num_pixs, 3) 101 | 102 | bc_img = np.zeros((h*w, 3), dtype=np.uint8) 103 | bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] 104 | bc_img = bc_img.reshape(h, w, 3) 105 | 106 | max_dist = max_dist.reshape(h, w) 107 | bc_pixs = max_dist > 5 108 | bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() 109 | fg_xys = 
np.stack(np.nonzero(bc_pixs)).transpose() 110 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) 111 | distances, indices = nbrs.kneighbors(bg_xys) 112 | bg_fg_xys = fg_xys[indices[:, 0]] 113 | bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] 114 | 115 | cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img) 116 | 117 | print(f'[INFO] ===== extracted background image =====') 118 | 119 | 120 | def extract_torso_and_gt(base_dir, ori_imgs_dir): 121 | 122 | print(f'[INFO] ===== extract torso and gt images for {base_dir} =====') 123 | 124 | from scipy.ndimage import binary_erosion, binary_dilation 125 | 126 | # load bg 127 | bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) 128 | 129 | image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) 130 | 131 | for image_path in tqdm.tqdm(image_paths): 132 | # read ori image 133 | ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] 134 | 135 | # read semantics 136 | seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) 137 | head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) 138 | neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) 139 | torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) 140 | bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) 141 | 142 | # get gt image 143 | gt_image = ori_image.copy() 144 | gt_image[bg_part] = bg_image[bg_part] 145 | cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) 146 | 147 | # get torso image 148 | torso_image = gt_image.copy() # rgb 149 | torso_image[head_part] = bg_image[head_part] 150 | torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha 151 | 152 | # torso part "vertical" in-painting... 
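# (What follows: for each torso column whose topmost pixel borders the head region, that
#  pixel's color is copied upward for L rows, darkened by a factor of 0.98 per row, so the
#  torso image keeps plausible colors in the strip the head normally covers.)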
153 | L = 8 + 1 154 | torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2] 155 | # lexsort: sort 2D coords first by y then by x, 156 | # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes 157 | inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1])) 158 | torso_coords = torso_coords[inds] 159 | # choose the top pixel for each column 160 | u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True) 161 | top_torso_coords = torso_coords[uid] # [m, 2] 162 | # only keep top-is-head pixels 163 | top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) 164 | mask = head_part[tuple(top_torso_coords_up.T)] 165 | if mask.any(): 166 | top_torso_coords = top_torso_coords[mask] 167 | # get the color 168 | top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3] 169 | # construct inpaint coords (vertically up, or minus in x) 170 | inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2] 171 | inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] 172 | inpaint_torso_coords += inpaint_offsets 173 | inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2] 174 | inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3] 175 | darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] 176 | inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] 177 | # set color 178 | torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors 179 | 180 | inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool) 181 | inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True 182 | else: 183 | inpaint_torso_mask = None 184 | 185 | 186 | # neck part "vertical" in-painting... 187 | push_down = 4 188 | L = 48 + push_down + 1 189 | 190 | neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3) 191 | 192 | neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2] 193 | # lexsort: sort 2D coords first by y then by x, 194 | # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes 195 | inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1])) 196 | neck_coords = neck_coords[inds] 197 | # choose the top pixel for each column 198 | u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True) 199 | top_neck_coords = neck_coords[uid] # [m, 2] 200 | # only keep top-is-head pixels 201 | top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0]) 202 | mask = head_part[tuple(top_neck_coords_up.T)] 203 | 204 | top_neck_coords = top_neck_coords[mask] 205 | # push these top down for 4 pixels to make the neck inpainting more natural... 
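# (the downward shift below is clamped by each column's pixel count, so neck columns with
#  fewer pixels than push_down are not pushed past the bottom of the neck region)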
206 | offset_down = np.minimum(ucnt[mask] - 1, push_down) 207 | top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1) 208 | # get the color 209 | top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3] 210 | # construct inpaint coords (vertically up, or minus in x) 211 | inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2] 212 | inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] 213 | inpaint_neck_coords += inpaint_offsets 214 | inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2] 215 | inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3] 216 | darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] 217 | inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] 218 | # set color 219 | torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors 220 | 221 | # apply blurring to the inpaint area to avoid vertical-line artifects... 222 | inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool) 223 | inpaint_mask[tuple(inpaint_neck_coords.T)] = True 224 | 225 | blur_img = torso_image.copy() 226 | blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT) 227 | 228 | torso_image[inpaint_mask] = blur_img[inpaint_mask] 229 | 230 | # set mask 231 | mask = (neck_part | torso_part | inpaint_mask) 232 | if inpaint_torso_mask is not None: 233 | mask = mask | inpaint_torso_mask 234 | torso_image[~mask] = 0 235 | torso_alpha[~mask] = 0 236 | 237 | cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) 238 | 239 | print(f'[INFO] ===== extracted torso and gt images =====') 240 | 241 | 242 | def face_tracking(ori_imgs_dir): 243 | 244 | print(f'[INFO] ===== perform face tracking =====') 245 | 246 | image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) 247 | 248 | # read one image to get H/W 249 | tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] 250 | h, w = tmp_image.shape[:2] 251 | 252 | cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}' 253 | 254 | os.system(cmd) 255 | 256 | print(f'[INFO] ===== finished face tracking =====') 257 | 258 | 259 | def save_transforms(base_dir, ori_imgs_dir): 260 | print(f'[INFO] ===== save transforms =====') 261 | 262 | import torch 263 | 264 | image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) 265 | 266 | # read one image to get H/W 267 | tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] 268 | h, w = tmp_image.shape[:2] 269 | 270 | params_dict = torch.load(os.path.join(base_dir, 'track_params.pt')) 271 | focal_len = params_dict['focal'] 272 | euler_angle = params_dict['euler'] 273 | trans = params_dict['trans'] / 10.0 274 | valid_num = euler_angle.shape[0] 275 | 276 | def euler2rot(euler_angle): 277 | batch_size = euler_angle.shape[0] 278 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 279 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 280 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 281 | one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) 282 | zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) 283 | rot_x = torch.cat(( 284 | torch.cat((one, zero, zero), 1), 285 | torch.cat((zero, theta.cos(), theta.sin()), 1), 286 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 287 | ), 2) 288 | rot_y = torch.cat(( 289 | 
torch.cat((phi.cos(), zero, -phi.sin()), 1), 290 | torch.cat((zero, one, zero), 1), 291 | torch.cat((phi.sin(), zero, phi.cos()), 1), 292 | ), 2) 293 | rot_z = torch.cat(( 294 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 295 | torch.cat((psi.sin(), psi.cos(), zero), 1), 296 | torch.cat((zero, zero, one), 1) 297 | ), 2) 298 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 299 | 300 | 301 | # train_val_split = int(valid_num*0.5) 302 | # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set. 303 | train_val_split = int(valid_num * 10 / 11) 304 | 305 | train_ids = torch.arange(0, train_val_split) 306 | val_ids = torch.arange(train_val_split, valid_num) 307 | 308 | rot = euler2rot(euler_angle) 309 | rot_inv = rot.permute(0, 2, 1) 310 | trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) 311 | 312 | pose = torch.eye(4, dtype=torch.float32) 313 | save_ids = ['train', 'val'] 314 | train_val_ids = [train_ids, val_ids] 315 | mean_z = -float(torch.mean(trans[:, 2]).item()) 316 | 317 | for split in range(2): 318 | transform_dict = dict() 319 | transform_dict['focal_len'] = float(focal_len[0]) 320 | transform_dict['cx'] = float(w/2.0) 321 | transform_dict['cy'] = float(h/2.0) 322 | transform_dict['frames'] = [] 323 | ids = train_val_ids[split] 324 | save_id = save_ids[split] 325 | 326 | for i in ids: 327 | i = i.item() 328 | frame_dict = dict() 329 | frame_dict['img_id'] = i 330 | frame_dict['aud_id'] = i 331 | 332 | pose[:3, :3] = rot_inv[i] 333 | pose[:3, 3] = trans_inv[i, :, 0] 334 | 335 | frame_dict['transform_matrix'] = pose.numpy().tolist() 336 | 337 | transform_dict['frames'].append(frame_dict) 338 | 339 | with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp: 340 | json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) 341 | 342 | print(f'[INFO] ===== finished saving transforms =====') 343 | 344 | 345 | if __name__ == '__main__': 346 | parser = argparse.ArgumentParser() 347 | parser.add_argument('path', type=str, help="path to video file") 348 | parser.add_argument('--task', type=int, default=-1, help="-1 means all") 349 | parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech") 350 | 351 | opt = parser.parse_args() 352 | 353 | base_dir = os.path.dirname(opt.path) 354 | 355 | wav_path = os.path.join(base_dir, 'aud.wav') 356 | ori_imgs_dir = os.path.join(base_dir, 'ori_imgs') 357 | parsing_dir = os.path.join(base_dir, 'parsing') 358 | gt_imgs_dir = os.path.join(base_dir, 'gt_imgs') 359 | torso_imgs_dir = os.path.join(base_dir, 'torso_imgs') 360 | 361 | os.makedirs(ori_imgs_dir, exist_ok=True) 362 | os.makedirs(parsing_dir, exist_ok=True) 363 | os.makedirs(gt_imgs_dir, exist_ok=True) 364 | os.makedirs(torso_imgs_dir, exist_ok=True) 365 | 366 | 367 | # extract audio 368 | if opt.task == -1 or opt.task == 1: 369 | extract_audio(opt.path, wav_path) 370 | 371 | # extract audio features 372 | if opt.task == -1 or opt.task == 2: 373 | extract_audio_features(wav_path, mode=opt.asr) 374 | 375 | # extract images 376 | if opt.task == -1 or opt.task == 3: 377 | extract_images(opt.path, ori_imgs_dir) 378 | 379 | # face parsing 380 | if opt.task == -1 or opt.task == 4: 381 | extract_semantics(ori_imgs_dir, parsing_dir) 382 | 383 | # extract bg 384 | if opt.task == -1 or opt.task == 5: 385 | extract_background(base_dir, ori_imgs_dir) 386 | 387 | # extract torso images and gt_images 388 | if opt.task == -1 or opt.task == 6: 389 | extract_torso_and_gt(base_dir, ori_imgs_dir) 390 | 391 | # extract face landmarks 392 | if 
opt.task == -1 or opt.task == 7: 393 | extract_landmarks(ori_imgs_dir) 394 | 395 | # face tracking 396 | if opt.task == -1 or opt.task == 8: 397 | face_tracking(ori_imgs_dir) 398 | 399 | # save transforms.json 400 | if opt.task == -1 or opt.task == 9: 401 | save_transforms(base_dir, ori_imgs_dir) 402 | 403 | -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def get_encoder(encoding, input_dim=3, 7 | multires=6, 8 | degree=4, 9 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 10 | **kwargs): 11 | 12 | if encoding == 'None': 13 | return lambda x, **kwargs: x, input_dim 14 | 15 | elif encoding == 'frequency': 16 | from freqencoder import FreqEncoder 17 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 18 | 19 | elif encoding == 'spherical_harmonics': 20 | from shencoder import SHEncoder 21 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 22 | 23 | elif encoding == 'hashgrid': 24 | from gridencoder import GridEncoder 25 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 26 | 27 | elif encoding == 'tiledgrid': 28 | from gridencoder import GridEncoder 29 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 30 | 31 | elif encoding == 'ash': 32 | from ashencoder import AshEncoder 33 | encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) 34 | 35 | else: 36 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]') 37 | 38 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /freqencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .freq import FreqEncoder -------------------------------------------------------------------------------- /freqencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | '-use_fast_math' 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
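# (Windows only: torch.utils.cpp_extension.load below JIT-compiles the CUDA sources on first
#  import and needs MSVC's cl.exe; setup.py in this folder is the ahead-of-time build alternative.)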
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | _backend = load(name='_freqencoder', 33 | extra_cflags=c_flags, 34 | extra_cuda_cflags=nvcc_flags, 35 | sources=[os.path.join(_src_path, 'src', f) for f in [ 36 | 'freqencoder.cu', 37 | 'bindings.cpp', 38 | ]], 39 | ) 40 | 41 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /freqencoder/freq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _freqencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | class _freq_encoder(Function): 16 | @staticmethod 17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 18 | def forward(ctx, inputs, degree, output_dim): 19 | # inputs: [B, input_dim], float 20 | # RETURN: [B, F], float 21 | 22 | if not inputs.is_cuda: inputs = inputs.cuda() 23 | inputs = inputs.contiguous() 24 | 25 | B, input_dim = inputs.shape # batch size, coord dim 26 | 27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 28 | 29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 30 | 31 | ctx.save_for_backward(inputs, outputs) 32 | ctx.dims = [B, input_dim, degree, output_dim] 33 | 34 | return outputs 35 | 36 | @staticmethod 37 | #@once_differentiable 38 | @custom_bwd 39 | def backward(ctx, grad): 40 | # grad: [B, C * C] 41 | 42 | grad = grad.contiguous() 43 | inputs, outputs = ctx.saved_tensors 44 | B, input_dim, degree, output_dim = ctx.dims 45 | 46 | grad_inputs = torch.zeros_like(inputs) 47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 48 | 49 | return grad_inputs, None, None 50 | 51 | 52 | freq_encode = _freq_encoder.apply 53 | 54 | 55 | class FreqEncoder(nn.Module): 56 | def __init__(self, input_dim=3, degree=4): 57 | super().__init__() 58 | 59 | self.input_dim = input_dim 60 | self.degree = degree 61 | self.output_dim = input_dim + input_dim * 2 * degree 62 | 63 | def __repr__(self): 64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" 65 | 66 | def forward(self, inputs, **kwargs): 67 | # inputs: [..., input_dim] 68 | # return: [..., ] 69 | 70 | prefix_shape = list(inputs.shape[:-1]) 71 | inputs = inputs.reshape(-1, self.input_dim) 72 | 73 | outputs = freq_encode(inputs, self.degree, self.output_dim) 74 | 75 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 76 | 77 | return outputs -------------------------------------------------------------------------------- /freqencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | '-use_fast_math' 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = 
['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='freqencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_freqencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'freqencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) -------------------------------------------------------------------------------- /freqencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "freqencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); 7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | inline constexpr __device__ float PI() { return 3.141592653589793f; } 22 | 23 | template 24 | __host__ __device__ T div_round_up(T val, T divisor) { 25 | return (val + divisor - 1) / divisor; 26 | } 27 | 28 | // inputs: [B, D] 29 | // outputs: [B, C], C = D + D * deg * 2 30 | __global__ void kernel_freq( 31 | const float * __restrict__ inputs, 32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 33 | float * outputs 34 | ) { 35 | // parallel on per-element 36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 37 | if (t >= B * C) return; 38 | 39 | // get index 40 | const uint32_t b = t / C; 41 | const uint32_t c = t - b * C; // t % C; 42 | 43 | // locate 44 | inputs += b * D; 45 | outputs += t; 46 | 47 | // write self 48 | if (c < D) { 49 | outputs[0] = inputs[c]; 50 | // write freq 51 | } else { 52 | const uint32_t col = c / D - 1; 53 | const uint32_t d = c % D; 54 | const uint32_t freq = col / 2; 55 | const float phase_shift = (col % 2) * (PI() / 2); 56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + 
phase_shift); 57 | } 58 | } 59 | 60 | // grad: [B, C], C = D + D * deg * 2 61 | // outputs: [B, C] 62 | // grad_inputs: [B, D] 63 | __global__ void kernel_freq_backward( 64 | const float * __restrict__ grad, 65 | const float * __restrict__ outputs, 66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 67 | float * grad_inputs 68 | ) { 69 | // parallel on per-element 70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 71 | if (t >= B * D) return; 72 | 73 | const uint32_t b = t / D; 74 | const uint32_t d = t - b * D; // t % D; 75 | 76 | // locate 77 | grad += b * C; 78 | outputs += b * C; 79 | grad_inputs += t; 80 | 81 | // register 82 | float result = grad[d]; 83 | grad += D; 84 | outputs += D; 85 | 86 | for (uint32_t f = 0; f < deg; f++) { 87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); 88 | grad += 2 * D; 89 | outputs += 2 * D; 90 | } 91 | 92 | // write 93 | grad_inputs[0] = result; 94 | } 95 | 96 | 97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { 98 | CHECK_CUDA(inputs); 99 | CHECK_CUDA(outputs); 100 | 101 | CHECK_CONTIGUOUS(inputs); 102 | CHECK_CONTIGUOUS(outputs); 103 | 104 | CHECK_IS_FLOATING(inputs); 105 | CHECK_IS_FLOATING(outputs); 106 | 107 | static constexpr uint32_t N_THREADS = 128; 108 | 109 | kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); 110 | } 111 | 112 | 113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { 114 | CHECK_CUDA(grad); 115 | CHECK_CUDA(outputs); 116 | CHECK_CUDA(grad_inputs); 117 | 118 | CHECK_CONTIGUOUS(grad); 119 | CHECK_CONTIGUOUS(outputs); 120 | CHECK_CONTIGUOUS(grad_inputs); 121 | 122 | CHECK_IS_FLOATING(grad); 123 | CHECK_IS_FLOATING(outputs); 124 | CHECK_IS_FLOATING(grad_inputs); 125 | 126 | static constexpr uint32_t N_THREADS = 128; 127 | 128 | kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); 129 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); 8 | 9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if 
os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | _interp_to_id = { 20 | 'linear': 0, 21 | 'smoothstep': 1, 22 | } 23 | 24 | class _grid_encode(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): 28 | # inputs: [B, D], float in [0, 1] 29 | # embeddings: [sO, C], float 30 | # offsets: [L + 1], int 31 | # RETURN: [B, F], float 32 | 33 | inputs = inputs.contiguous() 34 | 35 | B, D = inputs.shape # batch size, coord dim 36 | L = offsets.shape[0] - 1 # level 37 | C = embeddings.shape[1] # embedding dim for each level 38 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 39 | H = base_resolution # base resolution 40 | 41 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 42 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
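        # Why the even-C condition below: the CUDA backward accumulates into the embeddings with
        # atomicAdd, which is only fast in half precision when two half values can be packed and
        # added as a pair; with an odd level_dim that is not possible, so float32 is kept instead.
        # The inputs themselves always stay float32 so the grid coordinates keep full precision.
        # e.g. level_dim=2 (the value used elsewhere in this repo) trains the grid in half under
        # autocast, while level_dim=1 or 3 would silently fall back to float32.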
43 | if torch.is_autocast_enabled() and C % 2 == 0: 44 | embeddings = embeddings.to(torch.half) 45 | 46 | # L first, optimize cache for cuda kernel, but needs an extra permute later 47 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 48 | 49 | if calc_grad_inputs: 50 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) 51 | else: 52 | dy_dx = None 53 | 54 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) 55 | 56 | # permute back to [B, L * C] 57 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 58 | 59 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 60 | ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] 61 | ctx.align_corners = align_corners 62 | 63 | return outputs 64 | 65 | @staticmethod 66 | #@once_differentiable 67 | @custom_bwd 68 | def backward(ctx, grad): 69 | 70 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 71 | B, D, C, L, S, H, gridtype, interpolation = ctx.dims 72 | align_corners = ctx.align_corners 73 | 74 | # grad: [B, L * C] --> [L, B, C] 75 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 76 | 77 | grad_embeddings = torch.zeros_like(embeddings) 78 | 79 | if dy_dx is not None: 80 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 81 | else: 82 | grad_inputs = None 83 | 84 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) 85 | 86 | if dy_dx is not None: 87 | grad_inputs = grad_inputs.to(inputs.dtype) 88 | 89 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None 90 | 91 | 92 | 93 | grid_encode = _grid_encode.apply 94 | 95 | 96 | class GridEncoder(nn.Module): 97 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): 98 | super().__init__() 99 | 100 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 101 | if desired_resolution is not None: 102 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 103 | 104 | self.input_dim = input_dim # coord dims, 2 or 3 105 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 106 | self.level_dim = level_dim # encode channels per level 107 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
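        # Worked example of the schedule above, using the settings network.py requests for the
        # spatial grid (base_resolution=16, num_levels=16, desired_resolution=2048 when bound=1):
        #   per_level_scale = exp2(log2(2048 / 16) / (16 - 1)) = 2 ** (7 / 15) ≈ 1.38
        #   level i resolution = ceil(16 * 1.38 ** i), i.e. 16, 23, 31, ... up to 2048 at level 15.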
108 | self.log2_hashmap_size = log2_hashmap_size 109 | self.base_resolution = base_resolution 110 | self.output_dim = num_levels * level_dim 111 | self.gridtype = gridtype 112 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 113 | self.interpolation = interpolation 114 | self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" 115 | self.align_corners = align_corners 116 | 117 | # allocate parameters 118 | offsets = [] 119 | offset = 0 120 | self.max_params = 2 ** log2_hashmap_size 121 | for i in range(num_levels): 122 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 123 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 124 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 125 | offsets.append(offset) 126 | offset += params_in_level 127 | offsets.append(offset) 128 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 129 | self.register_buffer('offsets', offsets) 130 | 131 | self.n_params = offsets[-1] * level_dim 132 | 133 | # parameters 134 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 135 | 136 | self.reset_parameters() 137 | 138 | def reset_parameters(self): 139 | std = 1e-4 140 | self.embeddings.data.uniform_(-std, std) 141 | 142 | def __repr__(self): 143 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" 144 | 145 | def forward(self, inputs, bound=1): 146 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 147 | # return: [..., num_levels * level_dim] 148 | 149 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 150 | 151 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 152 | 153 | prefix_shape = list(inputs.shape[:-1]) 154 | inputs = inputs.view(-1, self.input_dim) 155 | 156 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) 157 | outputs = outputs.view(prefix_shape + [self.output_dim]) 158 | 159 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 160 | 161 | return outputs 162 | 163 | # always run in float precision! 164 | @torch.cuda.amp.autocast(enabled=False) 165 | def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): 166 | # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
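        # This adds a total-variation penalty gradient (differences between neighbouring grid
        # cells) directly onto self.embeddings.grad, so it has to be interleaved with an existing
        # backward pass. A minimal usage sketch (encoder/optimizer are placeholder names):
        #   loss.backward()
        #   encoder.grad_total_variation(weight=1e-7, bound=1)
        #   optimizer.step()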
167 | 168 | D = self.input_dim 169 | C = self.embeddings.shape[1] # embedding dim for each level 170 | L = self.offsets.shape[0] - 1 # level 171 | S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 172 | H = self.base_resolution # base resolution 173 | 174 | if inputs is None: 175 | # randomized in [0, 1] 176 | inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) 177 | else: 178 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 179 | inputs = inputs.view(-1, self.input_dim) 180 | B = inputs.shape[0] 181 | 182 | if self.embeddings.grad is None: 183 | raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') 184 | 185 | _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
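# The `where` query below only succeeds if cl.exe is already on PATH; otherwise the newest MSVC
# toolchain found under the standard Visual Studio install locations above is appended to PATH so
# the CUDA extension build can invoke it.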
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include <stdint.h> 5 | #include <torch/torch.h> 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | 17 | #endif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from nerf.provider import NeRFDataset 5 | from nerf.gui import NeRFGUI 6 | from nerf.utils import * 7 | 8 | # torch.autograd.set_detect_anomaly(True) 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('path', type=str) 14 | parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye") 15 | parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)") 16 | parser.add_argument('--test_train', action='store_true', help="test mode (load model and
train dataset)") 17 | parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use") 18 | parser.add_argument('--workspace', type=str, default='workspace') 19 | parser.add_argument('--seed', type=int, default=0) 20 | 21 | ### training options 22 | parser.add_argument('--iters', type=int, default=200000, help="training iters") 23 | parser.add_argument('--lr', type=float, default=5e-3, help="initial learning rate") 24 | parser.add_argument('--lr_net', type=float, default=5e-4, help="initial learning rate") 25 | parser.add_argument('--ckpt', type=str, default='latest') 26 | parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") 27 | parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") 28 | parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") 29 | parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") 30 | parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") 31 | parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") 32 | parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") 33 | 34 | 35 | ### network backbone options 36 | parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") 37 | 38 | parser.add_argument('--lambda_amb', type=float, default=0.1, help="lambda for ambient loss") 39 | 40 | parser.add_argument('--bg_img', type=str, default='', help="background image") 41 | parser.add_argument('--fbg', action='store_true', help="frame-wise bg") 42 | parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") 43 | parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") 44 | parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") 45 | 46 | parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") 47 | 48 | ### dataset options 49 | parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") 50 | parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") 51 | # (the default value is for the fox dataset) 52 | parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 53 | parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") 54 | parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") 55 | parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 56 | parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") 57 | parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") 58 | parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") 59 | parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") 60 | 61 | parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") 62 | parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...") 63 | 64 | parser.add_argument('--torso', action='store_true', help="fix head and train torso") 65 | parser.add_argument('--head_ckpt', type=str, default='', help="head model") 66 | 67 | ### GUI options 68 | parser.add_argument('--gui', action='store_true', help="start a GUI") 69 | parser.add_argument('--W', type=int, default=450, help="GUI width") 70 | parser.add_argument('--H', type=int, default=450, help="GUI height") 71 | parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center") 72 | parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy") 73 | parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") 74 | 75 | ### else 76 | parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") 77 | parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") 78 | parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") 79 | 80 | parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") 81 | parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") 82 | 83 | parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") 84 | 85 | parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension") 86 | parser.add_argument('--part', action='store_true', help="use partial training data (1/10)") 87 | parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") 88 | 89 | parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") 90 | parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") 91 | parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") 92 | 93 | # asr 94 | parser.add_argument('--asr', action='store_true', help="load asr for real-time app") 95 | parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") 96 | parser.add_argument('--asr_play', action='store_true', help="play out the audio") 97 | 98 | parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') 99 | # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') 100 | 101 | parser.add_argument('--asr_save_feats', 
action='store_true') 102 | # audio FPS 103 | parser.add_argument('--fps', type=int, default=50) 104 | # sliding window left-middle-right length (unit: 20ms) 105 | parser.add_argument('-l', type=int, default=10) 106 | parser.add_argument('-m', type=int, default=50) 107 | parser.add_argument('-r', type=int, default=10) 108 | 109 | opt = parser.parse_args() 110 | 111 | if opt.O: 112 | opt.fp16 = True 113 | opt.exp_eye = True 114 | 115 | if opt.test: 116 | opt.smooth_path = True 117 | opt.smooth_eye = True 118 | opt.smooth_lips = True 119 | 120 | opt.cuda_ray = True 121 | # assert opt.cuda_ray, "Only support CUDA ray mode." 122 | 123 | if opt.patch_size > 1: 124 | # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." 125 | assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." 126 | 127 | if opt.finetune_lips: 128 | # do not update density grid in finetune stage 129 | opt.update_extra_interval = 1e9 130 | 131 | from nerf.network import NeRFNetwork 132 | 133 | print(opt) 134 | 135 | seed_everything(opt.seed) 136 | 137 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 138 | 139 | model = NeRFNetwork(opt) 140 | 141 | # manually load state dict for head 142 | if opt.torso and opt.head_ckpt != '': 143 | 144 | model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model'] 145 | 146 | missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False) 147 | 148 | if len(missing_keys) > 0: 149 | print(f"[WARN] missing keys: {missing_keys}") 150 | if len(unexpected_keys) > 0: 151 | print(f"[WARN] unexpected keys: {unexpected_keys}") 152 | 153 | # freeze these keys 154 | for k, v in model.named_parameters(): 155 | if k in model_dict: 156 | # print(f'[INFO] freeze {k}, {v.shape}') 157 | v.requires_grad = False 158 | 159 | 160 | # print(model) 161 | 162 | criterion = torch.nn.MSELoss(reduction='none') 163 | 164 | if opt.test: 165 | 166 | if opt.gui: 167 | metrics = [] # use no metric in GUI for faster initialization... 168 | else: 169 | # metrics = [PSNRMeter(), LPIPSMeter(device=device)] 170 | metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')] 171 | 172 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) 173 | 174 | if opt.test_train: 175 | test_set = NeRFDataset(opt, device=device, type='train') 176 | # a manual fix to test on the training dataset 177 | test_set.training = False 178 | test_set.num_rays = -1 179 | test_loader = test_set.dataloader() 180 | else: 181 | test_loader = NeRFDataset(opt, device=device, type='test').dataloader() 182 | 183 | 184 | # temp fix: for update_extra_states 185 | model.aud_features = test_loader._data.auds 186 | model.eye_areas = test_loader._data.eye_area 187 | 188 | if opt.gui: 189 | # we still need test_loader to provide audio features for testing. 
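    # For reference, this test branch is reached by invocations such as (a sketch; the data path
    # and workspace are placeholders):
    #   python main.py data/<subject> --workspace <trial_dir> -O --test           # metrics + video
    #   python main.py data/<subject> --workspace <trial_dir> -O --test --gui     # interactive GUI
    # optionally with --aud <features>.npy to drive the portrait with different audio.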
190 | with NeRFGUI(opt, trainer, test_loader) as gui: 191 | gui.render() 192 | 193 | else: 194 | 195 | ### evaluate metrics (slow) 196 | if test_loader.has_gt: 197 | trainer.evaluate(test_loader) 198 | 199 | ### test and save video (fast) 200 | trainer.test(test_loader) 201 | 202 | else: 203 | 204 | optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr_net), betas=(0.9, 0.99), eps=1e-15) 205 | 206 | train_loader = NeRFDataset(opt, device=device, type='train').dataloader() 207 | 208 | assert len(train_loader) < opt.ind_num, f"[ERROR] dataset too many frames: {len(train_loader)}, please increase --ind_num to this number!" 209 | 210 | # temp fix: for update_extra_states 211 | model.aud_features = train_loader._data.auds 212 | model.eye_area = train_loader._data.eye_area 213 | model.poses = train_loader._data.poses 214 | 215 | # decay to 0.1 * init_lr at last iter step 216 | if opt.finetune_lips: 217 | scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.05 ** (iter / opt.iters)) 218 | else: 219 | scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** (iter / opt.iters)) 220 | 221 | metrics = [PSNRMeter(), LPIPSMeter(device=device)] 222 | 223 | eval_interval = max(1, int(5000 / len(train_loader))) 224 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval) 225 | 226 | if opt.gui: 227 | with NeRFGUI(opt, trainer, train_loader) as gui: 228 | gui.render() 229 | 230 | else: 231 | valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() 232 | 233 | max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) 234 | print(f'[INFO] max_epoch = {max_epoch}') 235 | trainer.train(train_loader, valid_loader, max_epoch) 236 | 237 | # free some mem 238 | del train_loader, valid_loader 239 | torch.cuda.empty_cache() 240 | 241 | # also test 242 | test_loader = NeRFDataset(opt, device=device, type='test').dataloader() 243 | 244 | if test_loader.has_gt: 245 | trainer.evaluate(test_loader) # blender has gt, so evaluate it. 
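            # Training finishes by rendering the held-out test split below. For reference, a
            # typical two-stage run of this script is (a sketch; paths are placeholders):
            #   python main.py data/<subject> --workspace <head_trial> -O --iters 200000
            #   python main.py data/<subject> --workspace <torso_trial> -O --torso --head_ckpt <head checkpoint> --iters 200000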
246 | 247 | trainer.test(test_loader) -------------------------------------------------------------------------------- /nerf/asr.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from transformers import AutoModelForCTC, AutoProcessor 6 | 7 | import pyaudio 8 | import soundfile as sf 9 | import resampy 10 | 11 | from queue import Queue 12 | from threading import Thread, Event 13 | 14 | 15 | def _read_frame(stream, exit_event, queue, chunk): 16 | 17 | while True: 18 | if exit_event.is_set(): 19 | print(f'[INFO] read frame thread ends') 20 | break 21 | frame = stream.read(chunk, exception_on_overflow=False) 22 | frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] 23 | queue.put(frame) 24 | 25 | def _play_frame(stream, exit_event, queue, chunk): 26 | 27 | while True: 28 | if exit_event.is_set(): 29 | print(f'[INFO] play frame thread ends') 30 | break 31 | frame = queue.get() 32 | frame = (frame * 32767).astype(np.int16).tobytes() 33 | stream.write(frame, chunk) 34 | 35 | class ASR: 36 | def __init__(self, opt): 37 | 38 | self.opt = opt 39 | 40 | self.play = opt.asr_play 41 | 42 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | self.fps = opt.fps # 20 ms per frame 44 | self.sample_rate = 16000 45 | self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) 46 | self.mode = 'live' if opt.asr_wav == '' else 'file' 47 | 48 | if 'esperanto' in self.opt.asr_model: 49 | self.audio_dim = 44 50 | elif 'deepspeech' in self.opt.asr_model: 51 | self.audio_dim = 29 52 | else: 53 | self.audio_dim = 32 54 | 55 | # prepare context cache 56 | # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms 57 | self.context_size = opt.m 58 | self.stride_left_size = opt.l 59 | self.stride_right_size = opt.r 60 | self.text = '[START]\n' 61 | self.terminated = False 62 | self.frames = [] 63 | 64 | # pad left frames 65 | if self.stride_left_size > 0: 66 | self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) 67 | 68 | 69 | self.exit_event = Event() 70 | self.audio_instance = pyaudio.PyAudio() 71 | 72 | # create input stream 73 | if self.mode == 'file': 74 | self.file_stream = self.create_file_stream() 75 | else: 76 | # start a background process to read frames 77 | self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) 78 | self.queue = Queue() 79 | self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) 80 | 81 | # play out the audio too...? 
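        # When --asr_play is set, a second (output) stream is opened below and fed from its own
        # queue by a background thread, mirroring the read thread above, so playback never blocks
        # feature extraction.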
82 | if self.play: 83 | self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) 84 | self.output_queue = Queue() 85 | self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) 86 | 87 | # current location of audio 88 | self.idx = 0 89 | 90 | # create wav2vec model 91 | print(f'[INFO] loading ASR model {self.opt.asr_model}...') 92 | self.processor = AutoProcessor.from_pretrained(opt.asr_model) 93 | self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) 94 | 95 | # prepare to save logits 96 | if self.opt.asr_save_feats: 97 | self.all_feats = [] 98 | 99 | # the extracted features 100 | # use a loop queue to efficiently record endless features: [f--t---][-------][-------] 101 | self.feat_buffer_size = 4 102 | self.feat_buffer_idx = 0 103 | self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) 104 | 105 | # TODO: hard coded 16 and 8 window size... 106 | self.front = self.feat_buffer_size * self.context_size - 8 # fake padding 107 | self.tail = 8 108 | # attention window... 109 | self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... 110 | 111 | # warm up steps needed: mid + right + window_size + attention_size 112 | self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 113 | 114 | self.listening = False 115 | self.playing = False 116 | 117 | def listen(self): 118 | # start 119 | if self.mode == 'live' and not self.listening: 120 | print(f'[INFO] starting read frame thread...') 121 | self.process_read_frame.start() 122 | self.listening = True 123 | 124 | if self.play and not self.playing: 125 | print(f'[INFO] starting play frame thread...') 126 | self.process_play_frame.start() 127 | self.playing = True 128 | 129 | def stop(self): 130 | 131 | self.exit_event.set() 132 | 133 | if self.play: 134 | self.output_stream.stop_stream() 135 | self.output_stream.close() 136 | if self.playing: 137 | self.process_play_frame.join() 138 | self.playing = False 139 | 140 | if self.mode == 'live': 141 | self.input_stream.stop_stream() 142 | self.input_stream.close() 143 | if self.listening: 144 | self.process_read_frame.join() 145 | self.listening = False 146 | 147 | 148 | def __enter__(self): 149 | return self 150 | 151 | def __exit__(self, exc_type, exc_value, traceback): 152 | 153 | self.stop() 154 | 155 | if self.mode == 'live': 156 | # live mode: also print the result text. 157 | self.text += '\n[END]' 158 | print(self.text) 159 | 160 | def get_next_feat(self): 161 | # return a [1/8, 16] window, for the next input to nerf side. 
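        # feat_queue acts as a ring buffer holding feat_buffer_size * context_size audio frames.
        # Each call slides a 16-frame window forward by 2 audio frames (2 * 20 ms = 40 ms of audio
        # per query) and returns the 8 most recent windows stacked as [8, audio_dim, 16], which is
        # the shape the AudioAttNet attention module in network.py consumes.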
162 | 163 | while len(self.att_feats) < 8: 164 | # [------f+++t-----] 165 | if self.front < self.tail: 166 | feat = self.feat_queue[self.front:self.tail] 167 | # [++t-----------f+] 168 | else: 169 | feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) 170 | 171 | self.front = (self.front + 2) % self.feat_queue.shape[0] 172 | self.tail = (self.tail + 2) % self.feat_queue.shape[0] 173 | 174 | # print(self.front, self.tail, feat.shape) 175 | 176 | self.att_feats.append(feat.permute(1, 0)) 177 | 178 | att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] 179 | 180 | # discard old 181 | self.att_feats = self.att_feats[1:] 182 | 183 | return att_feat 184 | 185 | def run_step(self): 186 | 187 | if self.terminated: 188 | return 189 | 190 | # get a frame of audio 191 | frame = self.get_audio_frame() 192 | 193 | # the last frame 194 | if frame is None: 195 | # terminate, but always run the network for the left frames 196 | self.terminated = True 197 | else: 198 | self.frames.append(frame) 199 | # put to output 200 | if self.play: 201 | self.output_queue.put(frame) 202 | # context not enough, do not run network. 203 | if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: 204 | return 205 | 206 | inputs = np.concatenate(self.frames) # [N * chunk] 207 | 208 | # discard the old part to save memory 209 | if not self.terminated: 210 | self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] 211 | 212 | logits, labels, text = self.frame_to_text(inputs) 213 | feats = logits # better lips-sync than labels 214 | 215 | # save feats 216 | if self.opt.asr_save_feats: 217 | self.all_feats.append(feats) 218 | 219 | # record the feats efficiently.. (no concat, constant memory) 220 | if not self.terminated: 221 | start = self.feat_buffer_idx * self.context_size 222 | end = start + feats.shape[0] 223 | self.feat_queue[start:end] = feats 224 | self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size 225 | 226 | # very naive, just concat the text output. 227 | if text != '': 228 | self.text = self.text + ' ' + text 229 | 230 | # will only run once at ternimation 231 | if self.terminated: 232 | self.text += '\n[END]' 233 | print(self.text) 234 | if self.opt.asr_save_feats: 235 | print(f'[INFO] save all feats for training purpose... 
') 236 | feats = torch.cat(self.all_feats, dim=0) # [N, C] 237 | # print('[INFO] before unfold', feats.shape) 238 | window_size = 16 239 | padding = window_size // 2 240 | feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] 241 | feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] 242 | unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] 243 | unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] 244 | # print('[INFO] after unfold', unfold_feats.shape) 245 | # save to a npy file 246 | if 'esperanto' in self.opt.asr_model: 247 | output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') 248 | else: 249 | output_path = self.opt.asr_wav.replace('.wav', '.npy') 250 | np.save(output_path, unfold_feats.cpu().numpy()) 251 | print(f"[INFO] saved logits to {output_path}") 252 | 253 | def create_file_stream(self): 254 | 255 | stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 256 | stream = stream.astype(np.float32) 257 | 258 | if stream.ndim > 1: 259 | print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') 260 | stream = stream[:, 0] 261 | 262 | if sample_rate != self.sample_rate: 263 | print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') 264 | stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) 265 | 266 | print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') 267 | 268 | return stream 269 | 270 | 271 | def create_pyaudio_stream(self): 272 | 273 | import pyaudio 274 | 275 | print(f'[INFO] creating live audio stream ...') 276 | 277 | audio = pyaudio.PyAudio() 278 | 279 | # get devices 280 | info = audio.get_host_api_info_by_index(0) 281 | n_devices = info.get('deviceCount') 282 | 283 | for i in range(0, n_devices): 284 | if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: 285 | name = audio.get_device_info_by_host_api_device_index(0, i).get('name') 286 | print(f'[INFO] choose audio device {name}, id {i}') 287 | break 288 | 289 | # get stream 290 | stream = audio.open(input_device_index=i, 291 | format=pyaudio.paInt16, 292 | channels=1, 293 | rate=self.sample_rate, 294 | input=True, 295 | frames_per_buffer=self.chunk) 296 | 297 | return audio, stream 298 | 299 | 300 | def get_audio_frame(self): 301 | 302 | if self.mode == 'file': 303 | 304 | if self.idx < self.file_stream.shape[0]: 305 | frame = self.file_stream[self.idx: self.idx + self.chunk] 306 | self.idx = self.idx + self.chunk 307 | return frame 308 | else: 309 | return None 310 | 311 | else: 312 | 313 | frame = self.queue.get() 314 | # print(f'[INFO] get frame {frame.shape}') 315 | 316 | self.idx = self.idx + self.chunk 317 | 318 | return frame 319 | 320 | 321 | def frame_to_text(self, frame): 322 | # frame: [N * 320], N = (context_size + 2 * stride_size) 323 | 324 | inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) 325 | 326 | with torch.no_grad(): 327 | result = self.model(inputs.input_values.to(self.device)) 328 | logits = result.logits # [1, N - 1, 32] 329 | 330 | # cut off stride 331 | left = max(0, self.stride_left_size) 332 | right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. 333 | 334 | # do not cut right if terminated. 
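        # Each segment fed to wav2vec2 covers (stride_left + context + stride_right) frames and
        # consecutive segments overlap by the two strides, so roughly only the middle context_size
        # outputs are kept here; the strides exist purely to give the model acoustic context at the
        # segment boundaries.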
335 | if self.terminated: 336 | right = logits.shape[1] 337 | 338 | logits = logits[:, left:right] 339 | 340 | # print(frame.shape, inputs.input_values.shape, logits.shape) 341 | 342 | predicted_ids = torch.argmax(logits, dim=-1) 343 | transcription = self.processor.batch_decode(predicted_ids)[0].lower() 344 | 345 | 346 | # for esperanto 347 | # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) 348 | 349 | # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) 350 | # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) 351 | # print(predicted_ids[0]) 352 | # print(transcription) 353 | 354 | return logits[0], predicted_ids[0], transcription # [N,] 355 | 356 | 357 | def run(self): 358 | 359 | self.listen() 360 | 361 | while not self.terminated: 362 | self.run_step() 363 | 364 | def clear_queue(self): 365 | # clear the queue, to reduce potential latency... 366 | print(f'[INFO] clear queue') 367 | if self.mode == 'live': 368 | self.queue.queue.clear() 369 | if self.play: 370 | self.output_queue.queue.clear() 371 | 372 | def warm_up(self): 373 | 374 | self.listen() 375 | 376 | print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') 377 | t = time.time() 378 | for _ in range(self.warm_up_steps): 379 | self.run_step() 380 | if torch.cuda.is_available(): 381 | torch.cuda.synchronize() 382 | t = time.time() - t 383 | print(f'[INFO] warm-up done, actual latency = {t:.6f}s') 384 | 385 | self.clear_queue() 386 | 387 | 388 | 389 | 390 | if __name__ == '__main__': 391 | import argparse 392 | 393 | parser = argparse.ArgumentParser() 394 | parser.add_argument('--wav', type=str, default='') 395 | parser.add_argument('--play', action='store_true', help="play out the audio") 396 | 397 | parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') 398 | # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') 399 | 400 | parser.add_argument('--save_feats', action='store_true') 401 | # audio FPS 402 | parser.add_argument('--fps', type=int, default=50) 403 | # sliding window left-middle-right length. 
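# One frame is 1 / fps seconds, i.e. 20 ms at the default fps of 50, so the defaults below mean a
# 200 ms left stride (-l 10), a 1 s middle context (-m 50) and a 200 ms right stride (-r 10),
# giving roughly (m + r) * 20 ms = 1.2 s of intrinsic streaming latency.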
404 | parser.add_argument('-l', type=int, default=10) 405 | parser.add_argument('-m', type=int, default=50) 406 | parser.add_argument('-r', type=int, default=10) 407 | 408 | opt = parser.parse_args() 409 | 410 | # fix 411 | opt.asr_wav = opt.wav 412 | opt.asr_play = opt.play 413 | opt.asr_model = opt.model 414 | opt.asr_save_feats = opt.save_feats 415 | 416 | if 'deepspeech' in opt.asr_model: 417 | raise ValueError("DeepSpeech features should not use this code to extract...") 418 | 419 | with ASR(opt) as asr: 420 | asr.run() -------------------------------------------------------------------------------- /nerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | # Audio feature extractor 10 | class AudioAttNet(nn.Module): 11 | def __init__(self, dim_aud=64, seq_len=8): 12 | super(AudioAttNet, self).__init__() 13 | self.seq_len = seq_len 14 | self.dim_aud = dim_aud 15 | self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len 16 | nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True), 17 | nn.LeakyReLU(0.02, True), 18 | nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), 19 | nn.LeakyReLU(0.02, True), 20 | nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), 21 | nn.LeakyReLU(0.02, True), 22 | nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), 23 | nn.LeakyReLU(0.02, True), 24 | nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), 25 | nn.LeakyReLU(0.02, True) 26 | ) 27 | self.attentionNet = nn.Sequential( 28 | nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), 29 | nn.Softmax(dim=1) 30 | ) 31 | 32 | def forward(self, x): 33 | # x: [1, seq_len, dim_aud] 34 | y = x.permute(0, 2, 1) # [1, dim_aud, seq_len] 35 | y = self.attentionConvNet(y) 36 | y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1) 37 | return torch.sum(y * x, dim=1) # [1, dim_aud] 38 | 39 | 40 | # Audio feature extractor 41 | class AudioNet(nn.Module): 42 | def __init__(self, dim_in=29, dim_aud=64, win_size=16): 43 | super(AudioNet, self).__init__() 44 | self.win_size = win_size 45 | self.dim_aud = dim_aud 46 | self.encoder_conv = nn.Sequential( # n x 29 x 16 47 | nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 8 48 | nn.LeakyReLU(0.02, True), 49 | nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 4 50 | nn.LeakyReLU(0.02, True), 51 | nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 2 52 | nn.LeakyReLU(0.02, True), 53 | nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 1 54 | nn.LeakyReLU(0.02, True), 55 | ) 56 | self.encoder_fc1 = nn.Sequential( 57 | nn.Linear(64, 64), 58 | nn.LeakyReLU(0.02, True), 59 | nn.Linear(64, dim_aud), 60 | ) 61 | 62 | def forward(self, x): 63 | half_w = int(self.win_size/2) 64 | x = x[:, :, 8-half_w:8+half_w] 65 | x = self.encoder_conv(x).squeeze(-1) 66 | x = self.encoder_fc1(x) 67 | return x 68 | 69 | class MLP(nn.Module): 70 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers): 71 | super().__init__() 72 | self.dim_in = dim_in 73 | self.dim_out = dim_out 74 | self.dim_hidden = dim_hidden 75 | self.num_layers = num_layers 76 | 77 | net = [] 78 | for l in range(num_layers): 79 | net.append(nn.Linear(self.dim_in if l 
== 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False)) 80 | 81 | self.net = nn.ModuleList(net) 82 | 83 | def forward(self, x): 84 | for l in range(self.num_layers): 85 | x = self.net[l](x) 86 | if l != self.num_layers - 1: 87 | x = F.relu(x, inplace=True) 88 | return x 89 | 90 | 91 | class NeRFNetwork(NeRFRenderer): 92 | def __init__(self, 93 | opt, 94 | # main network 95 | num_layers=3, 96 | hidden_dim=64, 97 | geo_feat_dim=64, 98 | num_layers_color=2, 99 | hidden_dim_color=64, 100 | # audio pre-encoder 101 | audio_dim=64, 102 | # deform_ambient net 103 | num_layers_ambient=3, 104 | hidden_dim_ambient=64, 105 | # ambient net 106 | ambient_dim=2, 107 | # torso net (hard coded for now) 108 | ): 109 | super().__init__(opt) 110 | 111 | # audio embedding 112 | self.emb = self.opt.emb 113 | 114 | if 'esperanto' in self.opt.asr_model: 115 | self.audio_in_dim = 44 116 | elif 'deepspeech' in self.opt.asr_model: 117 | self.audio_in_dim = 29 118 | else: 119 | self.audio_in_dim = 32 120 | 121 | if self.emb: 122 | self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim) 123 | 124 | # audio network 125 | self.audio_dim = audio_dim 126 | self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim) 127 | 128 | self.att = self.opt.att 129 | if self.att > 0: 130 | self.audio_att_net = AudioAttNet(self.audio_dim) 131 | 132 | # ambient network 133 | self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048 * self.bound, interpolation='linear') 134 | self.encoder_ambient, self.in_dim_ambient = get_encoder('tiledgrid', input_dim=ambient_dim, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048, interpolation='linear') 135 | 136 | self.num_layers_ambient = num_layers_ambient 137 | self.hidden_dim_ambient = hidden_dim_ambient 138 | self.ambient_dim = ambient_dim 139 | 140 | self.ambient_net = MLP(self.in_dim + self.audio_dim, self.ambient_dim, self.hidden_dim_ambient, self.num_layers_ambient) 141 | 142 | # sigma network 143 | self.num_layers = num_layers 144 | self.hidden_dim = hidden_dim 145 | self.geo_feat_dim = geo_feat_dim 146 | 147 | self.eye_dim = 1 if self.exp_eye else 0 148 | 149 | self.sigma_net = MLP(self.in_dim + self.in_dim_ambient + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers) 150 | 151 | # color network 152 | self.num_layers_color = num_layers_color 153 | self.hidden_dim_color = hidden_dim_color 154 | self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics') 155 | 156 | self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color) 157 | 158 | if self.torso: 159 | # torso deform network 160 | self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=10) 161 | self.pose_encoder, self.pose_in_dim = get_encoder('frequency', input_dim=6, multires=4) 162 | self.torso_deform_net = MLP(self.torso_deform_in_dim + self.pose_in_dim + self.individual_dim_torso, 2, 64, 3) 163 | 164 | # torso color network 165 | self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048, interpolation='linear') 166 | # self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.pose_in_dim + self.individual_dim_torso + self.audio_dim, 4, 64, 3) 167 | self.torso_net 
= MLP(self.torso_in_dim + self.torso_deform_in_dim + self.pose_in_dim + self.individual_dim_torso, 4, 32, 3) 168 | 169 | 170 | def encode_audio(self, a): 171 | # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech 172 | # if emb, a should be: [1, 16] or [8, 16] 173 | 174 | # fix audio traininig 175 | if a is None: return None 176 | 177 | if self.emb: 178 | a = self.embedding(a).transpose(-1, -2).contiguous() # [1/8, 29, 16] 179 | 180 | enc_a = self.audio_net(a) # [1/8, 64] 181 | 182 | if self.att > 0: 183 | enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 64] 184 | 185 | return enc_a 186 | 187 | 188 | def forward_torso(self, x, poses, enc_a, c=None): 189 | # x: [N, 2] in [-1, 1] 190 | # head poses: [1, 6] 191 | # c: [1, ind_dim], individual code 192 | 193 | # test: shrink x 194 | x = x * self.opt.torso_shrink 195 | 196 | # deformation-based 197 | enc_pose = self.pose_encoder(poses) 198 | enc_x = self.torso_deform_encoder(x) 199 | 200 | if c is not None: 201 | h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1) 202 | else: 203 | h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1)], dim=-1) 204 | 205 | dx = self.torso_deform_net(h) 206 | 207 | x = (x + dx).clamp(-1, 1) 208 | 209 | x = self.torso_encoder(x, bound=1) 210 | 211 | # h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1) 212 | h = torch.cat([x, h], dim=-1) 213 | 214 | h = self.torso_net(h) 215 | 216 | alpha = torch.sigmoid(h[..., :1]) 217 | color = torch.sigmoid(h[..., 1:]) 218 | 219 | return alpha, color, dx 220 | 221 | 222 | def forward(self, x, d, enc_a, c, e=None): 223 | # x: [N, 3], in [-bound, bound] 224 | # d: [N, 3], nomalized in [-1, 1] 225 | # enc_a: [1, aud_dim] 226 | # c: [1, ind_dim], individual code 227 | # e: [1, 1], eye feature 228 | 229 | # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 230 | # starter.record() 231 | 232 | if enc_a is None: 233 | ambient = torch.zeros_like(x[:, :self.ambient_dim]) 234 | enc_x = self.encoder(x, bound=self.bound) 235 | enc_w = self.encoder_ambient(ambient, bound=1) 236 | else: 237 | 238 | enc_a = enc_a.repeat(x.shape[0], 1) 239 | enc_x = self.encoder(x, bound=self.bound) 240 | 241 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"enocoder_deform = {curr_time}"); starter.record() 242 | 243 | # ambient 244 | ambient = torch.cat([enc_x, enc_a], dim=1) 245 | ambient = self.ambient_net(ambient).float() 246 | ambient = torch.tanh(ambient) # map to [-1, 1] 247 | 248 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"de-an net = {curr_time}"); starter.record() 249 | 250 | # sigma 251 | enc_w = self.encoder_ambient(ambient, bound=1) 252 | 253 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"encoder = {curr_time}"); starter.record() 254 | 255 | if e is not None: 256 | h = torch.cat([enc_x, enc_w, e.repeat(x.shape[0], 1)], dim=-1) 257 | else: 258 | h = torch.cat([enc_x, enc_w], dim=-1) 259 | 260 | h = self.sigma_net(h) 261 | 262 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"sigma_net = {curr_time}"); starter.record() 263 | sigma = trunc_exp(h[..., 0]) 264 | geo_feat = h[..., 1:] 265 | 266 | # color 267 | enc_d = self.encoder_dir(d) 268 | 269 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"encoder_dir = {curr_time}"); starter.record() 270 | 271 | if c is not None: 272 
| h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1) 273 | else: 274 | h = torch.cat([enc_d, geo_feat], dim=-1) 275 | 276 | h = self.color_net(h) 277 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"color_net = {curr_time}"); starter.record() 278 | 279 | # sigmoid activation for rgb 280 | color = torch.sigmoid(h) 281 | 282 | return sigma, color, ambient 283 | 284 | 285 | def density(self, x, enc_a, e=None): 286 | # x: [N, 3], in [-bound, bound] 287 | 288 | if enc_a is None: 289 | ambient = torch.zeros_like(x[:, :self.ambient_dim]) 290 | enc_x = self.encoder(x, bound=self.bound) 291 | enc_w = self.encoder_ambient(ambient, bound=1) 292 | else: 293 | 294 | enc_a = enc_a.repeat(x.shape[0], 1) 295 | enc_x = self.encoder(x, bound=self.bound) 296 | 297 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"enocoder_deform = {curr_time}"); starter.record() 298 | 299 | # ambient 300 | ambient = torch.cat([enc_x, enc_a], dim=1) 301 | ambient = self.ambient_net(ambient).float() 302 | ambient = torch.tanh(ambient) # map to [-1, 1] 303 | 304 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"de-an net = {curr_time}"); starter.record() 305 | 306 | # sigma 307 | enc_w = self.encoder_ambient(ambient, bound=1) 308 | 309 | # ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f"encoder = {curr_time}"); starter.record() 310 | 311 | if e is not None: 312 | h = torch.cat([enc_x, enc_w, e.repeat(x.shape[0], 1)], dim=-1) 313 | else: 314 | h = torch.cat([enc_x, enc_w], dim=-1) 315 | 316 | h = self.sigma_net(h) 317 | 318 | sigma = trunc_exp(h[..., 0]) 319 | geo_feat = h[..., 1:] 320 | 321 | return { 322 | 'sigma': sigma, 323 | 'geo_feat': geo_feat, 324 | } 325 | 326 | 327 | # optimizer utils 328 | def get_params(self, lr, lr_net, wd=0): 329 | 330 | # ONLY train torso 331 | if self.torso: 332 | params = [ 333 | {'params': self.torso_encoder.parameters(), 'lr': lr}, 334 | {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 335 | {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 336 | ] 337 | 338 | if self.individual_dim_torso > 0: 339 | params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd}) 340 | 341 | return params 342 | 343 | params = [ 344 | {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 345 | {'params': self.encoder.parameters(), 'lr': lr}, 346 | {'params': self.encoder_ambient.parameters(), 'lr': lr}, 347 | {'params': self.ambient_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 348 | {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 349 | {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, 350 | ] 351 | if self.att > 0: 352 | params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': wd}) 353 | if self.emb: 354 | params.append({'params': self.embedding.parameters(), 'lr': lr}) 355 | if self.individual_dim > 0: 356 | params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd}) 357 | if self.train_camera: 358 | params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0}) 359 | params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0}) 360 | 361 | return params -------------------------------------------------------------------------------- /raymarching/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching_face', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | # '-lineinfo', # to debug illegal memory access 10 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = ['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | ''' 34 | Usage: 35 | 36 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 37 | 38 | python setup.py install # build extensions and install (copy) to PATH. 39 | pip install . # ditto but better (e.g., dependency & metadata handling) 40 | 41 | python setup.py develop # build extensions and install (symbolic) to PATH. 42 | pip install -e . 
# ditto but better (e.g., dependency & metadata handling) 43 | 44 | ''' 45 | setup( 46 | name='raymarching_face', # package name, import this to use python API 47 | ext_modules=[ 48 | CUDAExtension( 49 | name='_raymarching_face', # extension name, import this to use CUDA API 50 | sources=[os.path.join(_src_path, 'src', f) for f in [ 51 | 'raymarching.cu', 52 | 'bindings.cpp', 53 | ]], 54 | extra_compile_args={ 55 | 'cxx': c_flags, 56 | 'nvcc': nvcc_flags, 57 | } 58 | ), 59 | ], 60 | cmdclass={ 61 | 'build_ext': BuildExtension, 62 | } 63 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("packbits", &packbits, "packbits (CUDA)"); 8 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 9 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 10 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 11 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 12 | m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)"); 13 | // train 14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 15 | m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)"); 16 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 17 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 18 | // infer 19 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 20 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 21 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation); 13 | 14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); 15 | void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d); 16 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor 
deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image); 17 | void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient); 18 | 19 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); 20 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # RAD-NeRF: Real-time Neural Talking Portrait Synthesis 2 | 3 | This repository contains a PyTorch re-implementation of the paper: [Real-time Neural Radiance Talking Portrait Synthesis via Audio-spatial Decomposition](https://arxiv.org/abs/2211.12368). 4 | 5 | Colab notebook demonstration: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ZsC6J-eeaOFP43Oi8DuY_aMSNUlM0A_c?usp=sharing) 6 | 7 | ### [Project Page](https://ashawkey.github.io/radnerf/) | [Arxiv](https://arxiv.org/abs/2211.12368) | [Data](https://drive.google.com/drive/folders/14LfowIkNdjRAD-0ezJ3JENWwY9_ytcXR?usp=sharing) 8 | 9 | A GUI for easy visualization: 10 | 11 | https://user-images.githubusercontent.com/25863658/201629660-7ada624b-8602-4cfe-96b3-61e3d465ced6.mp4 12 | 13 | # Install 14 | 15 | Tested on Ubuntu 22.04, Pytorch 1.12 and CUDA 11.6. 16 | 17 | ```bash 18 | git clone https://github.com/ashawkey/RAD-NeRF.git 19 | cd RAD-NeRF 20 | ``` 21 | 22 | ### Install dependency 23 | ```bash 24 | # for ubuntu, portaudio is needed for pyaudio to work. 25 | sudo apt install portaudio19-dev 26 | 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ### Build extension (optional) 31 | By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime. 32 | However, this may be inconvenient sometimes. 33 | Therefore, we also provide the `setup.py` to build each extension: 34 | ```bash 35 | # install all extension modules 36 | bash scripts/install_ext.sh 37 | ``` 38 | 39 | # Data pre-processing 40 | 41 | ### Preparation: 42 | 43 | ```bash 44 | ## install pytorch3d 45 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" 46 | 47 | ## prepare face-parsing model 48 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth 49 | 50 | ## prepare basel face model 51 | # 1. 
download `01_MorphableModel.mat` from https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads and put it under `data_utils/face_tracking/3DMM/` 52 | # 2. download other necessary files from AD-NeRF's repository: 53 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy 54 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy 55 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj 56 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy 57 | # 3. run convert_BFM.py 58 | cd data_utils/face_tracking 59 | python convert_BFM.py 60 | cd ../.. 61 | 62 | ## prepare ASR model 63 | # if you want to use DeepSpeech as AD-NeRF, you should install tensorflow 1.15 manually. 64 | # else, we also support Wav2Vec in PyTorch. 65 | ``` 66 | 67 | ### Pre-processing Custom Training Video 68 | * Put training video under `data/<ID>/<ID>.mp4`. 69 | 70 | The video **must be 25FPS, with all frames containing the talking person**. 71 | The resolution should be about 512x512, and duration about 1-5min. 72 | ```bash 73 | # an example training video from AD-NeRF 74 | mkdir -p data/obama 75 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/dataset/vids/Obama.mp4?raw=true -O data/obama/obama.mp4 76 | ``` 77 | 78 | * Run script (may take hours depending on the video length) 79 | ```bash 80 | # run all steps 81 | python data_utils/process.py data/<ID>/<ID>.mp4 82 | 83 | # if you want to run a specific step 84 | python data_utils/process.py data/<ID>/<ID>.mp4 --task 1 # extract audio wave 85 | ``` 86 | 87 | * File structure after finishing all steps: 88 | ```bash 89 | ./data/<ID> 90 | ├──<ID>.mp4 # original video 91 | ├──ori_imgs # original images from video 92 | │ ├──0.jpg 93 | │ ├──0.lms # 2D landmarks 94 | │ ├──... 95 | ├──gt_imgs # ground truth images (static background) 96 | │ ├──0.jpg 97 | │ ├──... 98 | ├──parsing # semantic segmentation 99 | │ ├──0.png 100 | │ ├──... 101 | ├──torso_imgs # inpainted torso images 102 | │ ├──0.png 103 | │ ├──... 104 | ├──aud.wav # original audio 105 | ├──aud_eo.npy # audio features (wav2vec) 106 | ├──aud.npy # audio features (deepspeech) 107 | ├──bc.jpg # default background 108 | ├──track_params.pt # raw head tracking results 109 | ├──transforms_train.json # head poses (train split) 110 | ├──transforms_val.json # head poses (test split) 111 | ``` 112 | 113 | # Usage 114 | 115 | ### Quick Start 116 | 117 | We provide some pretrained models [here](https://drive.google.com/drive/folders/14LfowIkNdjRAD-0ezJ3JENWwY9_ytcXR?usp=sharing) for quick testing on arbitrary audio. 118 | 119 | * Download a pretrained model. 120 | For example, we download `obama_eo.pth` to `./pretrained/obama_eo.pth` 121 | 122 | * Download a pose sequence file. 123 | For example, we download `obama.json` to `./data/obama.json` 124 | 125 | * Prepare your audio as `<name>.wav`, and extract audio features. 
126 | ```bash 127 | # if model is `<ID>_eo.pth`, it uses wav2vec features 128 | python nerf/asr.py --wav data/<name>.wav --save_feats # save to data/<name>_eo.npy 129 | 130 | # if model is `<ID>.pth`, it uses deepspeech features 131 | python data_utils/deepspeech_features/extract_ds_features.py --input data/<name>.wav # save to data/<name>.npy 132 | ``` 133 | You can download pre-processed audio features too. 134 | For example, we download `intro_eo.npy` to `./data/intro_eo.npy`. 135 | 136 | * Run inference: 137 | It takes about 2GB GPU memory to run inference at 40FPS (measured on a V100). 138 | ```bash 139 | # save video to trial_obama/results/*.mp4 140 | # if model is `<ID>.pth`, should append `--asr_model deepspeech` and use `--aud intro.npy` instead. 141 | python test.py --pose data/obama.json --ckpt pretrained/obama_eo.pth --aud data/intro_eo.npy --workspace trial_obama/ -O --torso 142 | 143 | # provide a background image (default is white) 144 | python test.py --pose data/obama.json --ckpt pretrained/obama_eo.pth --aud data/intro_eo.npy --workspace trial_obama/ -O --torso --bg_img data/bg.jpg 145 | 146 | # test with GUI 147 | python test.py --pose data/obama.json --ckpt pretrained/obama_eo.pth --aud data/intro_eo.npy --workspace trial_obama/ -O --torso --bg_img data/bg.jpg --gui 148 | ``` 149 | 150 | ### Detailed Usage 151 | 152 | The first run will take some time to compile the CUDA extensions. 153 | 154 | ```bash 155 | # train (head) 156 | # by default, we load data from disk on the fly. 157 | # we can also preload all data to CPU/GPU for faster training, but this is very memory-hungry for large datasets. 158 | # `--preload 0`: load from disk (default, slower). 159 | # `--preload 1`: load to CPU, requires ~70G CPU memory (slightly slower) 160 | # `--preload 2`: load to GPU, requires ~24G GPU memory (fast) 161 | python main.py data/obama/ --workspace trial_obama/ -O --iters 200000 162 | 163 | # train (finetune lips for another 50000 steps, run after the above command!) 164 | python main.py data/obama/ --workspace trial_obama/ -O --iters 250000 --finetune_lips 165 | 166 | # train (torso) 167 | # <head>.pth should be the latest checkpoint in trial_obama 168 | python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --head_ckpt <head>.pth --iters 200000 169 | 170 | # test on the test split 171 | python main.py data/obama/ --workspace trial_obama/ -O --test # use head checkpoint, will load GT torso 172 | python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test 173 | 174 | # test with GUI 175 | python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test --gui 176 | 177 | # test with GUI (load speech recognition model for real-time application) 178 | python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test --gui --asr 179 | 180 | # test with specific audio & pose sequence 181 | # --test_train: use train split for testing 182 | # --data_range: use this range's pose & eye sequence (if shorter than audio, automatically mirror and repeat) 183 | python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test --test_train --data_range 0 100 --aud data/intro_eo.npy 184 | ``` 185 | 186 | Check the `scripts` directory for more provided examples. 187 | 188 | 189 | # Acknowledgement 190 | 191 | * The data pre-processing part is adapted from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). 192 | * The NeRF framework is based on [torch-ngp](https://github.com/ashawkey/torch-ngp). 193 | * The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). 
194 | 195 | # Citation 196 | 197 | ``` 198 | @article{tang2022radnerf, 199 | title={Real-time Neural Radiance Talking Portrait Synthesis via Audio-spatial Decomposition}, 200 | author={Tang, Jiaxiang and Wang, Kaisiyuan and Zhou, Hang and Chen, Xiaokang and He, Dongliang and Hu, Tianshu and Liu, Jingtuo and Zeng, Gang and Wang, Jingdong}, 201 | journal={arXiv preprint arXiv:2211.12368}, 202 | year={2022} 203 | } 204 | ``` 205 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch-ema 2 | ninja 3 | trimesh 4 | opencv-python 5 | tensorboardX 6 | torch 7 | numpy 8 | pandas 9 | tqdm 10 | matplotlib 11 | PyMCubes 12 | rich 13 | dearpygui 14 | packaging 15 | scipy 16 | 17 | face_alignment 18 | python_speech_features 19 | numba 20 | resampy 21 | transformers 22 | pyaudio 23 | soundfile 24 | einops 25 | configargparse 26 | 27 | lpips 28 | imageio-ffmpeg 29 | -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | pip install ./freqencoder 2 | pip install ./shencoder 3 | pip install ./gridencoder 4 | pip install ./raymarching -------------------------------------------------------------------------------- /scripts/test_pretrained.sh: -------------------------------------------------------------------------------- 1 | # test obama eo 2 | python test.py \ 3 | --pose data/obama.json \ 4 | --ckpt pretrained/obama_eo.pth \ 5 | --aud data/intro_eo.npy \ 6 | --workspace trial_test \ 7 | --bg_img data/bg.jpg \ 8 | -O --torso --data_range 0 100 --preload 2 9 | 10 | # merge audio with video 11 | ffmpeg -y -i trial_test/results/ngp_ep0028.mp4 -i data/intro.wav -c:v copy -c:a aac obama_eo_intro.mp4 12 | 13 | # # test obama ds 14 | # python test.py \ 15 | # --pose data/obama.json \ 16 | # --ckpt pretrained/obama.pth \ 17 | # --aud data/intro.npy \ 18 | # --workspace trial_test \ 19 | # --bg_img data/bg.jpg \ 20 | # -O --torso --data_range 0 100 --asr_model deepspeech 21 | 22 | # # merge audio with video 23 | # ffmpeg -y -i trial_test/results/ngp_ep0056.mp4 -i data/intro.wav -c:v copy -c:a aac obama_intro.mp4 -------------------------------------------------------------------------------- /scripts/test_streaming.sh: -------------------------------------------------------------------------------- 1 | # end-to-end test with audio streaming 2 | python test.py \ 3 | --pose data/obama/transforms_train.json \ 4 | --ckpt trial_obama_eo_torso/checkpoints/ngp.pth \ 5 | --aud data/intro_eo.npy \ 6 | --workspace trial_test \ 7 | --bg_img data/obama/bc.jpg \ 8 | -l 10 -m 10 -r 10 \ 9 | -O --torso --data_range 0 100 --preload 2 --gui --asr -------------------------------------------------------------------------------- /scripts/train_obama_ds.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # train 4 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_ds/ -O --iters 200000 --asr_model deepspeech 5 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_ds/ -O --finetune_lips --iters 250000 --asr_model deepspeech 6 | 7 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_ds_torso/ -O --torso --iters 200000 --head_ckpt trial_obama_ds/checkpoints/ngp_ep0035.pth --asr_model deepspeech 8 | 9 | # test 10 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_ds_torso/ -O --torso --test --asr_model deepspeech -------------------------------------------------------------------------------- /scripts/train_obama_eo.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # train 4 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_eo/ -O --iters 200000 5 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_eo/ -O --finetune_lips --iters 250000 6 | 7 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_eo_torso/ -O --torso --iters 200000 --head_ckpt trial_obama_eo/checkpoints/ngp_ep0035.pth 8 | 9 | # test 10 | CUDA_VISIBLE_DEVICES=1 python main.py data/obama/ --workspace trial_obama_eo_torso/ -O --torso --test -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 
25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | 
ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 1 ~ 8 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only supports input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from nerf.provider import NeRFDataset_Test 5 | from nerf.gui import NeRFGUI 6 | from nerf.utils import * 7 | 8 | # torch.autograd.set_detect_anomaly(True) 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--pose', type=str, help="transforms.json, pose source") 14 | parser.add_argument('--aud', type=str, help="aud.npy, audio source") 15 | parser.add_argument('--bg_img', type=str, default='white', help="bg.jpg, background image source") 16 | 17 | parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye") 18 | # parser.add_argument('--test', action='store_true', help="test mode (load model and test 
dataset)") 19 | # parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)") 20 | parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use") 21 | parser.add_argument('--workspace', type=str, default='workspace') 22 | parser.add_argument('--seed', type=int, default=0) 23 | 24 | ### training options 25 | # parser.add_argument('--iters', type=int, default=200000, help="training iters") 26 | # parser.add_argument('--lr', type=float, default=5e-3, help="initial learning rate") 27 | # parser.add_argument('--lr_net', type=float, default=5e-4, help="initial learning rate") 28 | parser.add_argument('--ckpt', type=str, default='latest') 29 | parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") 30 | parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") 31 | parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") 32 | parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") 33 | parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") 34 | parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") 35 | parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") 36 | 37 | 38 | ### network backbone options 39 | parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") 40 | 41 | parser.add_argument('--lambda_amb', type=float, default=0.1, help="lambda for ambient loss") 42 | 43 | parser.add_argument('--fbg', action='store_true', help="frame-wise bg") 44 | parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") 45 | parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") 46 | parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") 47 | 48 | parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") 49 | 50 | ### dataset options 51 | parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") 52 | # parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") 53 | # (the default value is for the fox dataset) 54 | parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 55 | parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") 56 | parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") 57 | parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 58 | parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") 59 | parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") 60 | parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") 61 | parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") 62 | 63 | parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine-tune the lips region") 64 | parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in an exponential decay way...") 65 | 66 | parser.add_argument('--torso', action='store_true', help="fix head and train torso") 67 | parser.add_argument('--head_ckpt', type=str, default='', help="head model") 68 | 69 | ### GUI options 70 | parser.add_argument('--gui', action='store_true', help="start a GUI") 71 | parser.add_argument('--W', type=int, default=450, help="GUI width") 72 | parser.add_argument('--H', type=int, default=450, help="GUI height") 73 | parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center") 74 | parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy") 75 | parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") 76 | 77 | ### else 78 | parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") 79 | parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") 80 | 81 | parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") 82 | parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") 83 | 84 | parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") 85 | 86 | parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension") 87 | parser.add_argument('--part', action='store_true', help="use partial training data (1/10)") 88 | parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") 89 | 90 | parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") 91 | parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") 92 | parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") 93 | 94 | # asr 95 | parser.add_argument('--asr', action='store_true', help="load asr for real-time app") 96 | parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") 97 | parser.add_argument('--asr_play', action='store_true', help="play out the audio") 98 | 99 | parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') 100 | # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') 101 | 102 | parser.add_argument('--asr_save_feats', action='store_true') 103 | # audio FPS 104 | parser.add_argument('--fps', type=int, default=50) 105 | # sliding window left-middle-right length 
(unit: 20ms) 106 | parser.add_argument('-l', type=int, default=10) 107 | parser.add_argument('-m', type=int, default=50) 108 | parser.add_argument('-r', type=int, default=10) 109 | 110 | opt = parser.parse_args() 111 | 112 | # assert test mode 113 | opt.test = True 114 | opt.test_train = False 115 | 116 | # explicit smoothing 117 | opt.smooth_path = True 118 | opt.smooth_eye = True 119 | opt.smooth_lips = True 120 | 121 | assert opt.pose != '', 'Must provide a pose source' 122 | assert opt.aud != '', 'Must provide an audio source' 123 | 124 | if opt.O: 125 | opt.fp16 = True 126 | opt.exp_eye = True 127 | 128 | opt.cuda_ray = True 129 | # assert opt.cuda_ray, "Only support CUDA ray mode." 130 | 131 | from nerf.network import NeRFNetwork 132 | 133 | print(opt) 134 | 135 | seed_everything(opt.seed) 136 | 137 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 138 | 139 | model = NeRFNetwork(opt) 140 | 141 | # print(model) 142 | 143 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, fp16=opt.fp16, metrics=[], use_checkpoint=opt.ckpt) 144 | 145 | test_loader = NeRFDataset_Test(opt, device=device).dataloader() 146 | 147 | # temp fix: for update_extra_states 148 | model.aud_features = test_loader._data.auds 149 | model.eye_areas = test_loader._data.eye_area 150 | 151 | if opt.gui: 152 | # we still need test_loader to provide audio features for testing. 153 | with NeRFGUI(opt, trainer, test_loader) as gui: 154 | gui.render() 155 | 156 | else: 157 | 158 | ### test and save video (fast) 159 | trainer.test(test_loader) 160 | --------------------------------------------------------------------------------
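
For reference on the streaming ASR flags defined in `test.py` above (`--fps`, `-l`, `-m`, `-r`), here is a minimal sketch — not part of the repository — of how the sliding-window lengths translate into an audio context duration, assuming the defaults (50 audio-feature frames per second, i.e. 20 ms per unit):

```python
# Hypothetical helper for illustration only (not part of the repo):
# converts the -l / -m / -r sliding-window flags of test.py into a
# context duration in seconds, assuming `fps` audio-feature frames per
# second (20 ms per frame at the default fps=50).
def asr_context_seconds(l: int = 10, m: int = 50, r: int = 10, fps: int = 50) -> float:
    units = l + m + r        # total window length in feature frames
    return units / fps       # seconds of audio context per inference step

if __name__ == "__main__":
    # test.py defaults: -l 10 -m 50 -r 10 --fps 50  ->  1.4 s of context
    print(asr_context_seconds())
    # scripts/test_streaming.sh uses -l 10 -m 10 -r 10  ->  0.6 s of context
    print(asr_context_seconds(l=10, m=10, r=10))
```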