├── .gitignore ├── README.md ├── __init_paths.py ├── audio.py ├── hparams.py ├── inference.py ├── misc ├── StyleSync.png └── StyleSync0.png ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── requirements.txt ├── stylesync_model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | ckpts/* 3 | results/* 4 | _README.md 5 | demo/ 6 | mask*.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### This is a PyTorch implementation. Find the Paddle-based version [here](https://github.com/guanjz20/StyleSync). 2 | 3 | # StyleSync: High-Fidelity Generalized and Personalized Lip Sync in Style-based Generator (CVPR 2023) 4 | 5 | Jiazhi Guan*, Zhanwang Zhang*, [Hang Zhou](https://hangz-nju-cuhk.github.io/)†, [Tianshu Hu](https://scholar.google.com/citations?user=BIixVT0AAAAJ)†, [Kaisiyuan Wang](https://scholar.google.com/citations?user=2Pedf3EAAAAJ), [Dongliang He](https://scholar.google.com/citations?user=ui6DYGoAAAAJ), Haocheng Feng, [Jingtuo Liu](https://scholar.google.com/citations?user=tVV3jmcAAAAJ), [Errui Ding](https://scholar.google.com/citations?user=1wzEtxcAAAAJ), [Ziwei Liu](https://liuziwei7.github.io/), [Jingdong Wang](https://jingdongwang2017.github.io/) 6 | 7 | 8 | 9 | ### [Project](https://hangz-nju-cuhk.github.io/projects/StyleSync) | [Paper](https://arxiv.org/pdf/2305.05445.pdf) | [Video](https://www.youtube.com/watch?v=uuBglL2KGFc) | [Demo](https://www.youtube.com/watch?v=yAPDl2dVonY) 10 | 11 | We propose **StyleSync**, an effective framework that enables high-fidelity lip synchronization. We find that a style-based generator is sufficient to enable this property in both one-shot and few-shot scenarios. 
12 | 13 | ## Code 14 | Inference script and model code have been released. 15 | 16 | ## Run Generation 17 | The decision to suspend the release of the model weights has been made by our team based on various considerations. At present, I also do not have a specific expected date for its release. However, if you are interested in utilizing our demo for academic purposes, such as conducting a comparison in your paper, please don't hesitate to contact me by [guanjz20 at mails dot tsinghua dot edu dot cn]. 18 | 19 | 20 | ## Citation 21 | ``` 22 | @inproceedings{guan2023stylesync, 23 | title = {StyleSync: High-Fidelity Generalized and Personalized Lip Sync in Style-based Generator}, 24 | author = {Guan, Jiazhi and Zhang, Zhanwang and Zhou, Hang and HU, Tianshu and Wang, Kaisiyuan and He, Dongliang and Feng, Haocheng and Liu, Jingtuo and Ding, Errui and Liu, Ziwei and Wang, Jingdong}, 25 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 26 | year = {2023} 27 | } 28 | ``` -------------------------------------------------------------------------------- /__init_paths.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = osp.dirname(__file__) 9 | 10 | WORK_ROOT = '/home/guanjz' 11 | OUT_ROOT = './result' 12 | LOG_ROOT = './log' 13 | TEMP_ROOT ='./temp' 14 | DATA_ROOT = '' 15 | VIDEO_ROOT = '' -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | from scipy import signal 4 | from scipy.io import wavfile 5 | from hparams import hparams as hp 6 | 7 | 8 | def load_wav(path, sr): 9 | return librosa.core.load(path, sr=sr)[0] 10 | 11 | 12 | def save_wav(wav, path, sr): 13 | wav *= 32767 / 
max(0.01, np.max(np.abs(wav))) 14 | wavfile.write(path, sr, wav.astype(np.int16)) 15 | 16 | 17 | def save_wavenet_wav(wav, path, sr): 18 | librosa.output.write_wav(path, wav, sr=sr) 19 | 20 | 21 | def preemphasis(wav, k, preemphasize=True): 22 | if preemphasize: 23 | return signal.lfilter([1, -k], [1], wav) 24 | return wav 25 | 26 | 27 | def inv_preemphasis(wav, k, inv_preemphasize=True): 28 | if inv_preemphasize: 29 | return signal.lfilter([1], [1, -k], wav) 30 | return wav 31 | 32 | 33 | def get_hop_size(): 34 | hop_size = hp.hop_size 35 | if hop_size is None: 36 | assert hp.frame_shift_ms is not None 37 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 38 | return hop_size 39 | 40 | 41 | def linearspectrogram(wav): 42 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 43 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 44 | 45 | if hp.signal_normalization: 46 | return _normalize(S) 47 | return S 48 | 49 | 50 | def melspectrogram(wav): 51 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 52 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 53 | 54 | if hp.signal_normalization: 55 | return _normalize(S) 56 | return S 57 | 58 | 59 | def _lws_processor(): 60 | import lws 61 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 62 | 63 | 64 | def _stft(y): 65 | if hp.use_lws: 66 | return _lws_processor(hp).stft(y).T 67 | else: 68 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 69 | 70 | 71 | def num_frames(length, fsize, fshift): 72 | """Compute number of time frames of spectrogram 73 | """ 74 | pad = (fsize - fshift) 75 | if length % fshift == 0: 76 | M = (length + pad * 2 - fsize) // fshift + 1 77 | else: 78 | M = (length + pad * 2 - fsize) // fshift + 2 79 | return M 80 | 81 | 82 | def pad_lr(x, fsize, fshift): 83 | """Compute left and right padding 84 | """ 85 | M = num_frames(len(x), fsize, fshift) 86 | pad = (fsize - fshift) 87 | T = len(x) + 2 * 
pad 88 | r = (M - 1) * fshift + fsize - T 89 | return pad, pad + r 90 | 91 | 92 | def librosa_pad_lr(x, fsize, fshift): 93 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 94 | 95 | 96 | _mel_basis = None 97 | 98 | 99 | def _linear_to_mel(spectogram): 100 | global _mel_basis 101 | if _mel_basis is None: 102 | _mel_basis = _build_mel_basis() 103 | return np.dot(_mel_basis, spectogram) 104 | 105 | 106 | def _build_mel_basis(): 107 | assert hp.fmax <= hp.sample_rate // 2 108 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin, fmax=hp.fmax) 109 | 110 | 111 | def _amp_to_db(x): 112 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 113 | return 20 * np.log10(np.maximum(min_level, x)) 114 | 115 | 116 | def _db_to_amp(x): 117 | return np.power(10.0, (x) * 0.05) 118 | 119 | 120 | def _normalize(S): 121 | if hp.allow_clipping_in_normalization: 122 | if hp.symmetric_mels: 123 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, -hp.max_abs_value, 124 | hp.max_abs_value) 125 | else: 126 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 127 | 128 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 129 | if hp.symmetric_mels: 130 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 131 | else: 132 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 133 | 134 | 135 | def _denormalize(D): 136 | if hp.allow_clipping_in_normalization: 137 | if hp.symmetric_mels: 138 | return (((np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + 139 | hp.min_level_db) 140 | else: 141 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 142 | 143 | if hp.symmetric_mels: 144 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + 
hp.min_level_db) 145 | else: 146 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 147 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | 5 | class HParams: 6 | 7 | def __init__(self, **kwargs): 8 | self.data = {} 9 | 10 | for key, value in kwargs.items(): 11 | self.data[key] = value 12 | 13 | def __getattr__(self, key): 14 | if key not in self.data: 15 | raise AttributeError("'HParams' object has no attribute %s" % key) 16 | return self.data[key] 17 | 18 | def set_hparam(self, key, value): 19 | self.data[key] = value 20 | 21 | 22 | hparams = HParams( 23 | num_mels=80, 24 | rescale=True, 25 | rescaling_max=0.9, 26 | use_lws=False, 27 | n_fft=800, 28 | hop_size=200, 29 | win_size=800, 30 | sample_rate=16000, 31 | frame_shift_ms=None, 32 | signal_normalization=True, 33 | allow_clipping_in_normalization=True, 34 | symmetric_mels=True, 35 | max_abs_value=4., 36 | preemphasize=True, 37 | preemphasis=0.97, 38 | min_level_db=-100, 39 | ref_level_db=20, 40 | fmin=55, 41 | fmax=7600, 42 | img_size=256, 43 | fps=25, 44 | initial_learning_rate=1e-4, 45 | nepochs=200000000000000000, 46 | num_workers=4, 47 | checkpoint_interval=10000, 48 | eval_interval=10000, 49 | save_optimizer_state=False, 50 | syncnet_lr=1e-4, #1e-4 51 | syncnet_eval_interval=4000, 52 | syncnet_checkpoint_interval=4000, 53 | disc_initial_learning_rate=1e-4, 54 | syncnet_batch_size=256, 55 | batch_size=16, 56 | ) 57 | 58 | 59 | def hparams_debug_string(): 60 | values = hparams.values() 61 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 62 | return "Hyperparameters:\n" + "\n".join(hp) 63 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import cv2 3 | import sys 4 | import random 5 | import argparse 6 | import torch 7 | import tempfile 8 | import subprocess 9 | import face_alignment 10 | import numpy as np 11 | import os.path as osp 12 | from tqdm import tqdm 13 | from glob import glob 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--checkpoint_path', type=str, required=True) 17 | parser.add_argument('--face', type=str, required=True) 18 | parser.add_argument('--audio', type=str, required=True) 19 | parser.add_argument('--max_mel_chunks', type=int, default=None) 20 | parser.add_argument('--save_root', type=str, default='./results') 21 | parser.add_argument('--save_name', type=str, default=None) 22 | parser.add_argument('--fps', type=float, default=25) 23 | parser.add_argument('--img_size', type=int, default=256) 24 | parser.add_argument('--mask', default="mask3.jpg", type=str) 25 | # run 26 | parser.add_argument('--gpus', type=str, default='0') 27 | parser.add_argument('--device', type=str, default='cuda') 28 | parser.add_argument('--save_crop', action='store_true') 29 | parser.add_argument('--tag', type=str, default='StyleSync Inference...') 30 | 31 | args = parser.parse_args() 32 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 33 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 34 | 35 | import __init_paths 36 | import audio 37 | import utils 38 | 39 | 40 | def main(): 41 | # face 42 | temp_face_file = tempfile.NamedTemporaryFile(suffix=".mp4") 43 | if not os.path.isfile(args.face): 44 | fnames = list(glob(os.path.join(args.face, '*.jpg'))) 45 | sorted_fnames = sorted(fnames, key=lambda f: int(os.path.basename(f).split('.')[0])) 46 | full_frames = [cv2.imread(f) for f in sorted_fnames] 47 | fps = args.fps 48 | elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']: 49 | full_frames = [cv2.imread(args.face)] 50 | fps = args.fps 51 | elif args.face.split('.')[-1] in ['mp4', 'mov', 'MOV', 'MP4', 'webm']: 52 | video_stream = cv2.VideoCapture(args.face) 53 | fps = 
video_stream.get(cv2.CAP_PROP_FPS) 54 | try: 55 | assert fps == args.fps 56 | except: 57 | print('Converting video to fps 25...') 58 | video_name = temp_face_file.name 59 | command = 'ffmpeg -loglevel panic -i {} -qscale 0 -strict -2 -r {} -y {}'.format(args.face, fps, video_name) 60 | subprocess.call(command, shell=True) 61 | video_stream = cv2.VideoCapture(video_name) 62 | print('Reading video frames...') 63 | full_frames = [] 64 | while 1: 65 | print('Reading {}...'.format(len(full_frames)), end='\r') 66 | still_reading, frame = video_stream.read() 67 | if not still_reading: 68 | video_stream.release() 69 | break 70 | full_frames.append(frame) 71 | if args.max_mel_chunks and len(full_frames) > args.max_mel_chunks + 10: 72 | video_stream.release() 73 | break 74 | 75 | # audio 76 | temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav") 77 | if osp.basename(args.audio).split('.')[1] in ['wav', 'mp3']: 78 | wav_path = args.audio 79 | elif os.path.basename(args.audio).split('.')[1] in ['mp4', 'avi', 'MP4', 'AVI', 'MOV', 'mov', 'webm']: 80 | print('Extracting raw audio...') 81 | audio_name = temp_audio_file.name 82 | command = 'ffmpeg -i {} -loglevel error -y -f wav -acodec pcm_s16le -ar 16000 {}'.format(args.audio, audio_name) 83 | subprocess.call(command, shell=True) 84 | wav_path = audio_name 85 | 86 | # run 87 | with torch.no_grad(): 88 | print("Loading model...") 89 | model = utils.load_model(args.checkpoint_path) 90 | model = model.to(args.device) 91 | model.eval() 92 | 93 | save_name = args.save_name or '{}_{}.mp4'.format(osp.basename(args.face).split('.')[0], osp.basename(args.audio).split('.')[0]) 94 | save_path = os.path.join(args.save_root, save_name) 95 | os.makedirs(osp.dirname(save_path), exist_ok=True) 96 | print("=====>", save_path) 97 | infer_one(model, full_frames, wav_path, save_path) 98 | 99 | temp_face_file.close() 100 | temp_audio_file.close() 101 | 102 | 103 | def infer_one(model, imgs, wav_path, save_path): 104 | out_video_p = 
tempfile.NamedTemporaryFile(suffix=".avi") 105 | crop_out_video_p = tempfile.NamedTemporaryFile(suffix=".avi") 106 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=args.device) 107 | restorer = utils.AlignRestore() 108 | lmk_smoother = utils.laplacianSmooth() 109 | mel_chunks, _ = utils.read_wav(wav_path) 110 | mel_chunks = mel_chunks[:args.max_mel_chunks] 111 | img_mask = 1. - cv2.resize(cv2.imread(args.mask), (args.img_size, args.img_size), interpolation=cv2.INTER_AREA) / 255. 112 | img_idxs_org = list(range(len(imgs))) 113 | img_idxs_dst = [] 114 | while (len(img_idxs_dst) < len(mel_chunks) + 10): 115 | img_idxs_dst += img_idxs_org 116 | img_idxs_org = img_idxs_org[::-1] 117 | frame_h, frame_w, _ = imgs[0].shape 118 | out = cv2.VideoWriter(out_video_p.name, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h)) 119 | 120 | # img prep 121 | skip_begin = 0 122 | face_all = [] 123 | box_all = [] 124 | affine_matrix_all = [] 125 | face_crop_data_dict = {} 126 | print('Run face cropping...') 127 | for i, m in tqdm(enumerate(mel_chunks), total=len(mel_chunks)): 128 | img_idx = img_idxs_dst[i] 129 | if img_idx in face_crop_data_dict: 130 | _data = face_crop_data_dict[img_idx] 131 | affine_matrix_all.append(_data['affine_matrix']) 132 | box_all.append(_data['box']) 133 | face_all.append(_data['face']) 134 | continue 135 | img = imgs[img_idx].copy() 136 | try: 137 | re_lmks = fa.get_landmarks(img.copy()) 138 | points = lmk_smoother.smooth(re_lmks[0]) 139 | lmk3_ = np.zeros((3, 2)) 140 | lmk3_[0] = points[17:22].mean(0) 141 | lmk3_[1] = points[22:27].mean(0) 142 | lmk3_[2] = points[27:36].mean(0) 143 | except Exception as e: 144 | print('Face detection fail...\n[{}]'.format(e)) 145 | if len(affine_matrix_all) == 0: 146 | skip_begin += 1 147 | continue 148 | affine_matrix_all.append(affine_matrix_all[-1]) 149 | box_all.append(box_all[-1]) 150 | face_all.append(face_all[-1]) 151 | continue 152 | face, affine_matrix = 
restorer.align_warp_face(img.copy(), lmks3=lmk3_, smooth=True) 153 | box = [0, 0, face.shape[1], face.shape[0]] 154 | if i == 0 and args.save_crop: 155 | out_crop = cv2.VideoWriter(crop_out_video_p.name, cv2.VideoWriter_fourcc(*'DIVX'), 25, (face.shape[1], face.shape[0])) 156 | face = cv2.resize(face, (args.img_size, args.img_size), interpolation=cv2.INTER_CUBIC) 157 | affine_matrix_all.append(affine_matrix) 158 | box_all.append(box) 159 | face_all.append(face) 160 | face_crop_data_dict[img_idx] = {'affine_matrix': affine_matrix, 'box': box, 'face': face} 161 | while len(face_all) < len(mel_chunks): 162 | assert skip_begin > 0 163 | affine_matrix_all += affine_matrix_all[::-1] 164 | box_all += box_all[::-1] 165 | face_all += face_all[::-1] 166 | 167 | print('Run generation...') 168 | for i, m in tqdm(enumerate(mel_chunks), total=len(mel_chunks)): 169 | img = imgs[img_idxs_dst[i]].copy() 170 | face = face_all[i].copy() 171 | box = box_all[i] 172 | affine_matrix = affine_matrix_all[i] 173 | face_masked = face.copy() * img_mask 174 | ref_face = face.copy() 175 | 176 | # infer 177 | ref_faces = ref_face[np.newaxis, :] 178 | face_masked = face_masked[np.newaxis, :] 179 | mel_batch = mel_chunks[i][np.newaxis, :] 180 | img_batch = np.concatenate((face_masked, ref_faces), axis=3) / 255. 181 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 182 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(args.device) 183 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(args.device) 184 | pred = model(img_batch, mel_batch) 185 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 
186 | pred = pred.astype(np.uint8)[0] 187 | 188 | # save 189 | x1, y1, x2, y2 = box 190 | pred = cv2.resize(pred, (x2 - x1, y2 - y1), interpolation=cv2.INTER_CUBIC) 191 | if args.save_crop: 192 | out_crop.write(pred) 193 | out_img = restorer.restore_img(img, pred, affine_matrix) 194 | out.write(out_img) 195 | 196 | # write video 197 | out.release() 198 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -vcodec libx264 -crf 12 -pix_fmt yuv420p -shortest {}'.format( 199 | wav_path, out_video_p.name, save_path) 200 | subprocess.call(command, shell=True) 201 | if args.save_crop: 202 | out_crop.release() 203 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -vcodec libx264 -crf 12 -pix_fmt yuv420p -shortest {}'.format( 204 | wav_path, crop_out_video_p.name, save_path[:-4] + "_crop.mp4") 205 | subprocess.call(command, shell=True) 206 | 207 | out_video_p.close() 208 | crop_out_video_p.close() 209 | print('[DONE]') 210 | 211 | 212 | if __name__ == '__main__': 213 | main() 214 | -------------------------------------------------------------------------------- /misc/StyleSync.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjz20/StyleSync_PyTorch/ecafd9016a204bf39a4137268f432ef778ff93d6/misc/StyleSync.png -------------------------------------------------------------------------------- /misc/StyleSync0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjz20/StyleSync_PyTorch/ecafd9016a204bf39a4137268f432ef778ff93d6/misc/StyleSync0.png -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load, _import_module_from_library 9 | 10 | # if running GPEN without cuda, please comment line 11-19 11 | if platform.system() == 'Linux' and torch.cuda.is_available(): 12 | module_path = os.path.dirname(__file__) 13 | fused = load( 14 | 'fused', 15 | sources=[ 16 | os.path.join(module_path, 'fused_bias_act.cpp'), 17 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 18 | ], 19 | ) 20 | 21 | 22 | class FusedLeakyReLUFunctionBackward(Function): 23 | 24 | @staticmethod 25 | def forward(ctx, grad_output, out, negative_slope, scale): 26 | ctx.save_for_backward(out) 27 | ctx.negative_slope = negative_slope 28 | ctx.scale = scale 29 | 30 | empty = grad_output.new_empty(0) 31 | 32 | grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 33 | 34 | dim = [0] 35 | 36 | if grad_input.ndim > 2: 37 | dim += list(range(2, grad_input.ndim)) 38 | 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | return grad_input, grad_bias 42 | 43 | @staticmethod 44 | def backward(ctx, gradgrad_input, gradgrad_bias): 45 | out, = ctx.saved_tensors 46 | gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, 
grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | 74 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5, device='cpu'): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | self.device = device 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale, self.device) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5, device='cpu'): 87 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu': 88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 89 | else: 90 | return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope) 91 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", 
&fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& 
input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load, _import_module_from_library 8 | 9 | # if running GPEN without cuda, please comment line 10-18 10 | if platform.system() == 'Linux' and torch.cuda.is_available(): 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_op = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | 20 | 21 | class UpFirDn2dBackward(Function): 22 | 23 | @staticmethod 24 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = 
pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 82 | 83 | return gradgrad_out, None, None, None, None, None, None, None, None 84 | 85 | 86 | class UpFirDn2d(Function): 87 | 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 118 | out = out.view(-1, channel, out_h, out_w) 119 | 120 | return out 121 | 122 | 
@staticmethod 123 | def backward(ctx, grad_output): 124 | kernel, grad_kernel = ctx.saved_tensors 125 | 126 | grad_input = UpFirDn2dBackward.apply( 127 | grad_output, 128 | kernel, 129 | grad_kernel, 130 | ctx.up, 131 | ctx.down, 132 | ctx.pad, 133 | ctx.g_pad, 134 | ctx.in_size, 135 | ctx.out_size, 136 | ) 137 | 138 | return grad_input, None, None, None, None 139 | 140 | 141 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'): 142 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu': 143 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 144 | else: 145 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 151 | input = input.permute(0, 2, 3, 1) 152 | _, in_h, in_w, minor = input.shape 153 | kernel_h, kernel_w = kernel.shape 154 | out = input.view(-1, in_h, 1, in_w, 1, minor) 155 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 156 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 157 | 158 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 159 | out = out[ 160 | :, 161 | max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), 162 | max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), 163 | :, 164 | ] 165 | 166 | out = out.permute(0, 3, 1, 2) 167 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 168 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 169 | out = F.conv2d(out, w) 170 | out = out.reshape( 171 | -1, 172 | minor, 173 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 174 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 175 | ) 176 | return out[:, :, ::down_y, ::down_x] 177 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: 
--------------------------------------------------------------------------------
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>


// Integer division rounding toward negative infinity (for positive b).
// Plain `/` truncates toward zero, which is wrong for the negative
// positions produced by the `- pad` offsets in the kernel below.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    int c = a / b;

    if (c * b > a) {
        c--;
    }

    return c;
}


// Runtime parameters passed to every template instantiation of the kernel.
// Tensors are viewed as (major_dim, in_h, in_w, minor_dim); see upfirdn2d_op.
struct UpFirDn2DKernelParams {
    int up_x;
    int up_y;
    int down_x;
    int down_y;
    int pad_x0;
    int pad_x1;
    int pad_y0;
    int pad_y1;

    int major_dim;
    int in_h;
    int in_w;
    int minor_dim;
    int kernel_h;
    int kernel_w;
    int out_h;
    int out_w;
    int loop_major;  // major indices processed per block along grid z
    int loop_x;      // output tiles processed per block along grid y
};


// Upsample (zero insertion), pad, convolve with `kernel`, then downsample.
// Factors, maximum kernel size and output tile size are template parameters
// so the shared-memory arrays have static sizes. Taps beyond the runtime
// p.kernel_h x p.kernel_w are zero-filled.
//
// Block mapping (set up in upfirdn2d_op): blockIdx.x enumerates
// (output tile row, minor index) pairs, blockIdx.y groups of p.loop_x tiles
// along x, and blockIdx.z groups of p.loop_major major indices.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
    // Input footprint needed to produce one output tile.
    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

    // Staged as float regardless of scalar_t.
    __shared__ volatile float sk[kernel_h][kernel_w];
    __shared__ volatile float sx[tile_in_h][tile_in_w];

    int minor_idx = blockIdx.x;
    int tile_out_y = minor_idx / p.minor_dim;
    minor_idx -= tile_out_y * p.minor_dim;
    tile_out_y *= tile_out_h;
    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
    int major_idx_base = blockIdx.z * p.loop_major;

    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
        return;
    }

    // Load the kernel into shared memory, flipped in both axes.
    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
        int ky = tap_idx / kernel_w;
        int kx = tap_idx - ky * kernel_w;
        scalar_t v = 0.0;

        if (kx < p.kernel_w & ky < p.kernel_h) {
            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
        }

        sk[ky][kx] = v;
    }

    for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
        for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
            int tile_in_x = floor_div(tile_mid_x, up_x);
            int tile_in_y = floor_div(tile_mid_y, up_y);

            __syncthreads();

            // Stage the input tile; out-of-range pixels (padding) read as 0.
            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
                int rel_in_y = in_idx / tile_in_w;
                int rel_in_x = in_idx - rel_in_y * tile_in_w;
                int in_x = rel_in_x + tile_in_x;
                int in_y = rel_in_y + tile_in_y;

                scalar_t v = 0.0;

                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
                }

                sx[rel_in_y][rel_in_x] = v;
            }

            __syncthreads();
            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
                int rel_out_y = out_idx / tile_out_w;
                int rel_out_x = out_idx - rel_out_y * tile_out_w;
                int out_x = rel_out_x + tile_out_x;
                int out_y = rel_out_y + tile_out_y;

                int mid_x = tile_mid_x + rel_out_x * down_x;
                int mid_y = tile_mid_y + rel_out_y * down_y;
                int in_x = floor_div(mid_x, up_x);
                int in_y = floor_div(mid_y, up_y);
                int rel_in_x = in_x - tile_in_x;
                int rel_in_y = in_y - tile_in_y;
                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
                int kernel_y = (in_y + 1) * up_y - mid_y - 1;

                scalar_t v = 0.0;

                // Only every up-th tap hits a non-zero (non-inserted) sample.
                #pragma unroll
                for (int y = 0; y < kernel_h / up_y; y++)
                    #pragma unroll
                    for (int x = 0; x < kernel_w / up_x; x++)
                        v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];

                if (out_x < p.out_w & out_y < p.out_h) {
                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
                }
            }
        }
    }
}


// Host entry point. `input` is (major, in_h, in_w, minor) and `kernel` is
// (kernel_h, kernel_w); returns a (major, out_h, out_w, minor) tensor.
// Only the specialised configurations below are supported.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    UpFirDn2DKernelParams p;

    auto x = input.contiguous();
    auto k = kernel.contiguous();

    p.major_dim = x.size(0);
    p.in_h = x.size(1);
    p.in_w = x.size(2);
    p.minor_dim = x.size(3);
    p.kernel_h = k.size(0);
    p.kernel_w = k.size(1);
    p.up_x = up_x;
    p.up_y = up_y;
    p.down_x = down_x;
    p.down_y = down_y;
    p.pad_x0 = pad_x0;
    p.pad_x1 = pad_x1;
    p.pad_y0 = pad_y0;
    p.pad_y1 = pad_y1;

    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

    // Select a specialised kernel. Later matches intentionally override
    // earlier ones (e.g. a 3x3 kernel matches mode 1 and then mode 2).
    int mode = -1;

    int tile_out_h = -1;
    int tile_out_w = -1;

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 1;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
        mode = 2;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 3;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 4;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 5;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 6;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    // No compiled kernel matches: fail loudly instead of returning the
    // uninitialised `out` tensor.
    TORCH_CHECK(mode != -1,
                "upfirdn2d: unsupported configuration (up=", p.up_x, "x", p.up_y,
                ", down=", p.down_x, "x", p.down_y,
                ", kernel=", p.kernel_h, "x", p.kernel_w, ")");

    dim3 block_size;
    dim3 grid_size;

    if (tile_out_h > 0 && tile_out_w > 0) {
        p.loop_major = (p.major_dim - 1) / 16384 + 1;
        p.loop_x = 1;
        block_size = dim3(32 * 8, 1, 1);
        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                         (p.major_dim - 1) / p.loop_major + 1);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
        switch (mode) {
        case 1:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 2:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 3:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 4:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 5:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 6:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;
        }
    });

    return out;
}
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.2 2 | matplotlib==3.5.1 3 | munch==2.5.0 4 | ninja==1.10.2.4 5 | opencv-python==4.6.0.66 6 | opencv-python-headless==4.6.0.66 7 | scikit-image==0.19.2 8 | torch==1.12.1 9 | torchvision==0.13.1 10 | tqdm==4.64.0 11 | timm==0.6.13 12 | face-alignment==1.3.5 -------------------------------------------------------------------------------- /stylesync_model.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import re 3 | import math 4 | import random 5 | import itertools 6 | import logging 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | import utils 12 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 13 | 14 | _logger = logging.getLogger('model') 15 | 16 | 17 | def make_kernel(k): 18 | k = torch.tensor(k, dtype=torch.float32) 19 | 20 | if k.ndim == 1: 21 | k = k[None, :] * k[:, None] 22 | 23 | k /= k.sum() 24 | 25 | return k 26 | 27 | 28 | class PixelNorm(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, input): 34 | return input * torch.rsqrt(torch.mean(input**2, dim=1, keepdim=True) + 1e-8) 35 | 36 | 37 | class Upsample(nn.Module): 38 | 39 | def __init__(self, kernel, factor=2, device='cpu'): 40 | super().__init__() 41 | 42 | self.factor = factor 43 | kernel = make_kernel(kernel) * (factor**2) 44 | self.register_buffer('kernel', kernel) 45 | 46 | p = kernel.shape[0] - factor 47 | 48 | pad0 = (p + 1) // 2 + factor - 1 49 | pad1 = p // 2 50 | 51 | self.pad = (pad0, pad1) 52 | self.device = device 53 | 54 | def forward(self, input): 55 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad, device=self.device) 56 | 57 | return out 58 | 59 | 60 | class Downsample(nn.Module): 61 | 62 | def __init__(self, 
kernel, factor=2, device='cpu'): 63 | super().__init__() 64 | 65 | self.factor = factor 66 | kernel = make_kernel(kernel) 67 | self.register_buffer('kernel', kernel) 68 | 69 | p = kernel.shape[0] - factor 70 | 71 | pad0 = (p + 1) // 2 72 | pad1 = p // 2 73 | 74 | self.pad = (pad0, pad1) 75 | self.device = device 76 | 77 | def forward(self, input): 78 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad, device=self.device) 79 | 80 | return out 81 | 82 | 83 | class Blur(nn.Module): 84 | 85 | def __init__(self, kernel, pad, upsample_factor=1, device='cpu'): 86 | super().__init__() 87 | 88 | kernel = make_kernel(kernel) 89 | 90 | if upsample_factor > 1: 91 | kernel = kernel * (upsample_factor**2) 92 | 93 | self.register_buffer('kernel', kernel) 94 | 95 | self.pad = pad 96 | self.device = device 97 | 98 | def forward(self, input): 99 | out = upfirdn2d(input, self.kernel, pad=self.pad, device=self.device) 100 | 101 | return out 102 | 103 | 104 | class EqualConv2d(nn.Module): 105 | 106 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): 107 | super().__init__() 108 | 109 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) 110 | self.scale = 1 / math.sqrt(in_channel * kernel_size**2) 111 | 112 | self.stride = stride 113 | self.padding = padding 114 | 115 | if bias: 116 | self.bias = nn.Parameter(torch.zeros(out_channel)) 117 | else: 118 | self.bias = None 119 | 120 | def forward(self, input): 121 | out = F.conv2d( 122 | input, 123 | self.weight * self.scale, 124 | bias=self.bias, 125 | stride=self.stride, 126 | padding=self.padding, 127 | ) 128 | 129 | return out 130 | 131 | def __repr__(self): 132 | return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 133 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})') 134 | 135 | 136 | class EqualLinear(nn.Module): 137 | 138 | def __init__(self, in_dim, out_dim, bias=True, 
bias_init=0, lr_mul=1, activation=None, device='cpu'): 139 | super().__init__() 140 | 141 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 142 | 143 | if bias: 144 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 145 | else: 146 | self.bias = None 147 | 148 | self.activation = activation 149 | self.device = device 150 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 151 | self.lr_mul = lr_mul 152 | 153 | def forward(self, input): 154 | if self.activation: 155 | out = F.linear(input, self.weight * self.scale) 156 | out = fused_leaky_relu(out, self.bias * self.lr_mul, device=self.device) 157 | else: 158 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) 159 | 160 | return out 161 | 162 | def __repr__(self): 163 | return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') 164 | 165 | 166 | class ScaledLeakyReLU(nn.Module): 167 | 168 | def __init__(self, negative_slope=0.2): 169 | super().__init__() 170 | self.negative_slope = negative_slope 171 | 172 | def forward(self, input): 173 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 174 | 175 | return out * math.sqrt(2) 176 | 177 | 178 | class ModulatedConv2d(nn.Module): 179 | 180 | def __init__(self, 181 | in_channel, 182 | out_channel, 183 | kernel_size, 184 | style_dim, 185 | demodulate=True, 186 | upsample=False, 187 | downsample=False, 188 | blur_kernel=[1, 3, 3, 1], 189 | device='cpu'): 190 | super().__init__() 191 | 192 | self.eps = 1e-8 193 | self.kernel_size = kernel_size 194 | self.in_channel = in_channel 195 | self.out_channel = out_channel 196 | self.upsample = upsample 197 | self.downsample = downsample 198 | 199 | if upsample: 200 | factor = 2 201 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 202 | pad0 = (p + 1) // 2 + factor - 1 203 | pad1 = p // 2 + 1 204 | 205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor, device=device) 206 | 207 | if downsample: 208 | factor 
= 2 209 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 210 | pad0 = (p + 1) // 2 211 | pad1 = p // 2 212 | 213 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), device=device) 214 | 215 | fan_in = in_channel * kernel_size**2 216 | self.scale = 1 / math.sqrt(fan_in) 217 | self.padding = kernel_size // 2 218 | 219 | self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) 220 | 221 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 222 | 223 | self.demodulate = demodulate 224 | 225 | def __repr__(self): 226 | return (f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 227 | f'upsample={self.upsample}, downsample={self.downsample})') 228 | 229 | def forward(self, input, style): 230 | batch, in_channel, height, width = input.shape 231 | 232 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 233 | weight = self.scale * self.weight * style 234 | 235 | if self.demodulate: 236 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 237 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 238 | 239 | weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size) 240 | 241 | if self.upsample: 242 | input = input.view(1, batch * in_channel, height, width) 243 | weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size) 244 | weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size) 245 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 246 | _, _, height, width = out.shape 247 | out = out.view(batch, self.out_channel, height, width) 248 | out = self.blur(out) 249 | 250 | elif self.downsample: 251 | input = self.blur(input) 252 | _, _, height, width = input.shape 253 | input = input.view(1, batch * in_channel, height, width) 254 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 255 | _, 
_, height, width = out.shape 256 | out = out.view(batch, self.out_channel, height, width) 257 | 258 | else: 259 | input = input.view(1, batch * in_channel, height, width) 260 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 261 | _, _, height, width = out.shape 262 | out = out.view(batch, self.out_channel, height, width) 263 | 264 | return out 265 | 266 | 267 | class NoiseInjection(nn.Module): 268 | 269 | def __init__(self, isconcat=True): 270 | super().__init__() 271 | 272 | self.isconcat = isconcat 273 | self.weight = nn.Parameter(torch.zeros(1)) 274 | 275 | def forward(self, image, noise=None): 276 | if noise is None: 277 | batch, channel, height, width = image.shape 278 | noise = image.new_empty(batch, channel, height, width).normal_() 279 | 280 | if self.isconcat: 281 | return torch.cat((image, self.weight * noise), dim=1) 282 | else: 283 | return image + self.weight * noise 284 | 285 | 286 | class ConstantInput(nn.Module): 287 | 288 | def __init__(self, channel, size=4): 289 | super().__init__() 290 | 291 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 292 | 293 | def forward(self, input): 294 | batch = input.shape[0] 295 | out = self.input.repeat(batch, 1, 1, 1) 296 | 297 | return out 298 | 299 | 300 | class StyledConv(nn.Module): 301 | 302 | def __init__(self, 303 | in_channel, 304 | out_channel, 305 | kernel_size, 306 | style_dim, 307 | upsample=False, 308 | blur_kernel=[1, 3, 3, 1], 309 | demodulate=True, 310 | isconcat=True, 311 | device='cpu'): 312 | super().__init__() 313 | 314 | self.conv = ModulatedConv2d(in_channel, 315 | out_channel, 316 | kernel_size, 317 | style_dim, 318 | upsample=upsample, 319 | blur_kernel=blur_kernel, 320 | demodulate=demodulate, 321 | device=device) 322 | 323 | self.noise = NoiseInjection(isconcat) 324 | feat_multiplier = 2 if isconcat else 1 325 | self.activate = FusedLeakyReLU(out_channel * feat_multiplier, device=device) 326 | 327 | def forward(self, input, style, noise=None): 328 | out 
= self.conv(input, style) 329 | out = self.noise(out, noise=noise) 330 | out = self.activate(out) 331 | 332 | return out 333 | 334 | 335 | class ToRGB(nn.Module): 336 | 337 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], device='cpu', out_channel=3): 338 | super().__init__() 339 | 340 | if upsample: 341 | self.upsample = Upsample(blur_kernel, device=device) 342 | 343 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False, device=device) 344 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 345 | 346 | def forward(self, input, style, skip=None): 347 | out = self.conv(input, style) 348 | out = out + self.bias 349 | 350 | if skip is not None: 351 | skip = self.upsample(skip) 352 | out = out + skip 353 | 354 | return out 355 | 356 | 357 | class ConvLayer(nn.Sequential): 358 | 359 | def __init__(self, 360 | in_channel, 361 | out_channel, 362 | kernel_size, 363 | downsample=False, 364 | blur_kernel=[1, 3, 3, 1], 365 | bias=True, 366 | activate=True, 367 | device='cpu'): 368 | layers = [] 369 | 370 | if downsample: 371 | factor = 2 372 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 373 | pad0 = (p + 1) // 2 374 | pad1 = p // 2 375 | 376 | layers.append(Blur(blur_kernel, pad=(pad0, pad1), device=device)) 377 | 378 | stride = 2 379 | self.padding = 0 380 | 381 | else: 382 | stride = 1 383 | self.padding = kernel_size // 2 384 | 385 | layers.append(EqualConv2d( 386 | in_channel, 387 | out_channel, 388 | kernel_size, 389 | padding=self.padding, 390 | stride=stride, 391 | bias=bias and not activate, 392 | )) 393 | 394 | if activate: 395 | if bias: 396 | layers.append(FusedLeakyReLU(out_channel, device=device)) 397 | else: 398 | layers.append(ScaledLeakyReLU(0.2)) 399 | 400 | super().__init__(*layers) 401 | 402 | 403 | class ResBlock(nn.Module): 404 | 405 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], device='cpu'): 406 | super().__init__() 407 | 408 | self.conv1 = 
ConvLayer(in_channel, in_channel, 3, device=device) 409 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, device=device) 410 | 411 | self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) 412 | 413 | def forward(self, input): 414 | out = self.conv1(input) 415 | out = self.conv2(out) 416 | 417 | skip = self.skip(input) 418 | out = (out + skip) / math.sqrt(2) 419 | 420 | return out 421 | 422 | 423 | class audioConv2d(nn.Module): 424 | 425 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 426 | super().__init__(*args, **kwargs) 427 | num_groups = 32 428 | self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), 429 | nn.GroupNorm(num_groups=num_groups, num_channels=cout)) 430 | self.act = nn.LeakyReLU(0.01, inplace=True) 431 | self.residual = residual 432 | 433 | def forward(self, x): 434 | out = self.conv_block(x) 435 | if self.residual: 436 | out += x 437 | return self.act(out) 438 | 439 | 440 | class AudioEncoder(nn.Module): 441 | 442 | def __init__(self, lr_mlp=0.01, device='cpu'): 443 | super().__init__() 444 | self.encoder = nn.Sequential( 445 | audioConv2d(1, 32, kernel_size=3, stride=1, padding=1), 446 | audioConv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 447 | audioConv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 448 | audioConv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 449 | audioConv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 450 | audioConv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 451 | audioConv2d(64, 128, kernel_size=3, stride=3, padding=1), 452 | audioConv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 453 | audioConv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 454 | audioConv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 455 | audioConv2d(256, 256, kernel_size=3, stride=1, padding=1, 
residual=True), 456 | audioConv2d(256, 512, kernel_size=3, stride=1, padding=0), 457 | audioConv2d(512, 512, kernel_size=1, stride=1, padding=0), 458 | ) 459 | self.linear = nn.Sequential(EqualLinear(512, 512, activation='fused_lrelu', device=device)) 460 | 461 | def forward(self, x): 462 | x = self.encoder(x) 463 | x = x.view(x.shape[0], -1) 464 | x = self.linear(x) 465 | return x 466 | 467 | 468 | class Generator(nn.Module): 469 | 470 | def __init__( 471 | self, 472 | size, 473 | style_dim, 474 | n_mlp, 475 | channel_multiplier=2, 476 | blur_kernel=[1, 3, 3, 1], 477 | lr_mlp=0.01, 478 | isconcat=True, 479 | narrow=1, 480 | device='cpu', 481 | ): 482 | super().__init__() 483 | 484 | self.size = size 485 | self.n_mlp = n_mlp 486 | self.style_dim = style_dim 487 | self.feat_multiplier = 2 if isconcat else 1 488 | 489 | layers = [PixelNorm()] 490 | 491 | for i in range(n_mlp): 492 | layers.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu', device=device)) 493 | 494 | self.style = nn.Sequential(*layers) 495 | 496 | self.channels = { 497 | 4: int(512 * narrow), 498 | 8: int(512 * narrow), 499 | 16: int(512 * narrow), 500 | 32: int(512 * narrow), 501 | 64: int(256 * channel_multiplier * narrow), 502 | 128: int(128 * channel_multiplier * narrow), 503 | 256: int(64 * channel_multiplier * narrow), 504 | 512: int(32 * channel_multiplier * narrow), 505 | 1024: int(16 * channel_multiplier * narrow), 506 | 2048: int(8 * channel_multiplier * narrow) 507 | } 508 | 509 | self.input = ConstantInput(self.channels[4]) 510 | self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device) 511 | self.to_rgb1 = ToRGB(self.channels[4] * self.feat_multiplier, style_dim, upsample=False, device=device) 512 | 513 | self.log_size = int(math.log(size, 2)) 514 | self.num_layers = (self.log_size - 2) * 2 + 1 515 | 516 | self.convs = nn.ModuleList() 517 | self.upsamples = nn.ModuleList() 518 | 
self.to_rgbs = nn.ModuleList() 519 | 520 | in_channel = self.channels[4] 521 | 522 | for i in range(3, self.log_size + 1): 523 | out_channel = self.channels[2**i] 524 | 525 | self.convs.append( 526 | StyledConv(in_channel * self.feat_multiplier, 527 | out_channel, 528 | 3, 529 | style_dim, 530 | upsample=True, 531 | blur_kernel=blur_kernel, 532 | isconcat=isconcat, 533 | device=device)) 534 | 535 | self.convs.append( 536 | StyledConv(out_channel * self.feat_multiplier, 537 | out_channel, 538 | 3, 539 | style_dim, 540 | blur_kernel=blur_kernel, 541 | isconcat=isconcat, 542 | device=device)) 543 | 544 | self.to_rgbs.append(ToRGB(out_channel * self.feat_multiplier, style_dim, device=device)) 545 | 546 | in_channel = out_channel 547 | 548 | self.n_latent = self.log_size * 2 - 2 549 | 550 | def forward( 551 | self, 552 | styles, 553 | return_latents=False, 554 | inject_index=None, 555 | truncation=1, 556 | truncation_latent=None, 557 | input_is_latent=False, 558 | noise=None, 559 | w_plus=False, 560 | delta_w=None, 561 | ): 562 | if not input_is_latent: 563 | styles = [self.style(s) for s in styles] 564 | 565 | if noise is None: 566 | noise = [None] * (2 * (self.log_size - 2) + 1) 567 | 568 | if truncation < 1: 569 | style_t = [] 570 | for style in styles: 571 | style_t.append(truncation_latent + truncation * (style - truncation_latent)) 572 | styles = style_t 573 | 574 | if len(styles) < 2: 575 | if not w_plus: 576 | inject_index = self.n_latent 577 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 578 | else: 579 | latent = styles[0] 580 | assert latent.shape[1] == self.n_latent 581 | else: 582 | if inject_index is None: 583 | inject_index = random.randint(1, self.n_latent - 1) 584 | 585 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 586 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 587 | 588 | latent = torch.cat([latent, latent2], 1) 589 | 590 | out = self.input(latent) 591 | out = self.conv1(out, latent[:, 0], 
noise=noise[0]) 592 | skip = self.to_rgb1(out, latent[:, 1]) 593 | i = 1 594 | for idx, (conv1, conv2, noise1, noise2, 595 | to_rgb) in enumerate(zip(self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs)): 596 | out = conv1(out, latent[:, i], noise=noise1) 597 | out = conv2(out, latent[:, i + 1], noise=noise2) 598 | skip = to_rgb(out, latent[:, i + 2], skip) 599 | i += 2 600 | image = skip 601 | return image 602 | 603 | def make_noise(self): 604 | device = self.input.input.device 605 | noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] 606 | for i in range(3, self.log_size + 1): 607 | for _ in range(2): 608 | noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) 609 | 610 | return noises 611 | 612 | def mean_latent(self, n_latent): 613 | latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device) 614 | latent = self.style(latent_in).mean(0, keepdim=True) 615 | return latent 616 | 617 | def get_latent(self, input): 618 | return self.style(input) 619 | 620 | 621 | class FullGenerator(nn.Module): 622 | 623 | def __init__( 624 | self, 625 | size, 626 | style_dim, 627 | n_mlp, 628 | channel_multiplier=2, 629 | blur_kernel=[1, 3, 3, 1], 630 | lr_mlp=0.01, 631 | isconcat=True, 632 | narrow=1, 633 | device='cpu', 634 | mask_p='', 635 | mask_n_noise=None, 636 | face_z=False, 637 | tune_k=None, 638 | n_mlp_tune=2, 639 | noise_channel=False, 640 | noise_mask_p=None, 641 | ): 642 | super().__init__() 643 | 644 | self.size = size 645 | self.face_z = face_z 646 | self.mask_n_noise = mask_n_noise 647 | self.mask_p = mask_p 648 | self.noise_mask_p = noise_mask_p or self.mask_p 649 | self.act = nn.Sigmoid() 650 | 651 | self.audio_encoder = AudioEncoder(device=device) 652 | 653 | channels = { 654 | 4: int(512 * narrow), 655 | 8: int(512 * narrow), 656 | 16: int(512 * narrow), 657 | 32: int(512 * narrow), 658 | 64: int(256 * channel_multiplier * narrow), 659 | 128: int(128 * channel_multiplier * narrow), 660 | 256: int(64 * 
channel_multiplier * narrow), 661 | 512: int(32 * channel_multiplier * narrow), 662 | 1024: int(16 * channel_multiplier * narrow), 663 | 2048: int(8 * channel_multiplier * narrow) 664 | } 665 | self.channels = channels 666 | 667 | self.log_size = int(math.log(size, 2)) 668 | self.generator = Generator( 669 | size, 670 | style_dim, 671 | n_mlp, 672 | channel_multiplier=channel_multiplier, 673 | blur_kernel=blur_kernel, 674 | lr_mlp=lr_mlp, 675 | isconcat=isconcat, 676 | narrow=narrow, 677 | device=device, 678 | ) 679 | 680 | conv = [ConvLayer(6, channels[size], 1, device=device)] 681 | self.ecd0 = nn.Sequential(*conv) 682 | in_channel = channels[size] 683 | 684 | self.names = ['ecd%d' % i for i in range(self.log_size - 1)] 685 | for i in range(self.log_size, 2, -1): 686 | out_channel = channels[2**(i - 1)] 687 | conv = [ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)] 688 | setattr(self, self.names[self.log_size - i + 1], nn.Sequential(*conv)) 689 | in_channel = out_channel 690 | 691 | if self.mask_n_noise: 692 | size = self.size 693 | mask_mouth_region = cv2.imread(self.noise_mask_p) 694 | mask_mouth_region = cv2.resize(mask_mouth_region, (size, size)) 695 | mask_back_region = 1. - mask_mouth_region[:, :, 0] / 255. 
696 | mask_back_region_torch = torch.from_numpy(mask_back_region).float().view(1, 1, size, size) 697 | self.mask_back_region_list = [mask_back_region_torch] 698 | if self.mask_n_noise > 1: 699 | for _ in range(1, self.mask_n_noise): 700 | size = size // 2 701 | mask_back_region = cv2.resize(mask_back_region, (size, size)) 702 | mask_back_region_torch = torch.from_numpy(mask_back_region).float().view(1, 1, size, size) 703 | self.mask_back_region_list.append(mask_back_region_torch) 704 | 705 | if self.face_z: 706 | self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu', device=device)) 707 | self.cat_linear = nn.Sequential(EqualLinear(style_dim * 2, style_dim, activation='fused_lrelu', device=device)) 708 | 709 | # if tune_k: 710 | # for k in tune_k: 711 | # _logger.info('Creating modules finetuned in [{}] ...'.format(k)) 712 | # self.finetune(k, freeze_other=False, n_mlp_tune=n_mlp_tune, noise_channel=noise_channel, device=device) 713 | 714 | def forward( 715 | self, 716 | face_sequences, 717 | audio_sequences, 718 | return_latents=False, 719 | inject_index=None, 720 | truncation=1, 721 | truncation_latent=None, 722 | input_is_latent=False, 723 | inversion_init=False, 724 | sm_audio_w=0.3, 725 | a_alpha=1.25, 726 | ): 727 | face = self.tensor5to4(face_sequences) 728 | audio = self.tensor5to4_audio(audio_sequences) 729 | inputs_masked = face[:, :3] 730 | inputs_ref = face[:, 3:] 731 | 732 | audio_feat = self.audio_encoder(audio) 733 | if sm_audio_w > 0: 734 | sm_audio_feat = getattr(self, 'sm_audio_feat', None) 735 | if sm_audio_feat is None: 736 | sm_audio_feat = audio_feat 737 | sm_audio_feat = sm_audio_w * sm_audio_feat + (1 - sm_audio_w) * audio_feat 738 | audio_feat = sm_audio_feat 739 | setattr(self, 'sm_audio_feat', sm_audio_feat) 740 | if a_alpha > 0: 741 | audio_feat *= a_alpha 742 | outs = audio_feat 743 | 744 | noise = [] 745 | inputs = face 746 | for i in range(self.log_size - 1): 747 | ecd = getattr(self, 
self.names[i]) 748 | inputs = ecd(inputs) 749 | noise.append(inputs) 750 | face_feat_final = inputs 751 | 752 | if self.mask_n_noise: 753 | for j in range(self.mask_n_noise): 754 | noise_local = noise[j] 755 | mask_local = self.mask_back_region_list[j].type_as(noise_local).to(noise_local.device) 756 | noise[j] = noise_local * mask_local 757 | repeat_noise = list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in noise))[::-1] 758 | 759 | if self.face_z: 760 | face_feat = self.final_linear(face_feat_final.view(face_feat_final.shape[0], -1)) 761 | outs = self.cat_linear(torch.cat([outs, face_feat], dim=1)) 762 | 763 | outs = self.generator( 764 | [outs], 765 | False, 766 | inject_index, 767 | truncation, 768 | truncation_latent, 769 | input_is_latent, 770 | noise=repeat_noise[1:], 771 | ) 772 | image = self.act(outs) 773 | return image 774 | 775 | def tensor5to4(self, input): 776 | input_dim_size = len(input.size()) 777 | if input_dim_size > 4: 778 | b, c, t, h, w = input.size() 779 | input = input.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) 780 | return input 781 | 782 | def tensor5to4_audio(self, input): 783 | input_dim_size = len(input.size()) 784 | if input_dim_size > 4: 785 | b, t, c, h, w = input.size() 786 | input = input.reshape(-1, c, h, w) 787 | return input 788 | 789 | 790 | if __name__ == '__main__': 791 | model = FullGenerator(256, 512, 8) 792 | img_batch = torch.ones(2, 6, 5, 256, 256) 793 | mel_batch = torch.ones(2, 5, 1, 80, 16) 794 | x = model(img_batch, mel_batch) 795 | print(x.shape) 796 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import tempfile 6 | import subprocess 7 | import numpy as np 8 | import os.path as osp 9 | from munch import munchify 10 | 11 | import audio 12 | from stylesync_model import FullGenerator 13 | 14 | 15 | def 
requires_grad(model, flag=True): 16 | for p in model.parameters(): 17 | p.requires_grad = flag 18 | 19 | 20 | def load_model(path): 21 | with open(osp.join(osp.dirname(path), 'args.json')) as f: 22 | model_args = munchify(json.load(f)) 23 | assert not model_args.gpen_norm 24 | generator = create_generator(model_args) 25 | print("Load checkpoint from: {}".format(path)) 26 | checkpoint = torch.load(path, map_location='cpu') 27 | s = checkpoint["state_dict"] 28 | new_s = {} 29 | for k, v in s.items(): 30 | new_s[k.replace('module.', '')] = v 31 | model = generator 32 | model.load_state_dict(new_s) 33 | return model 34 | 35 | 36 | def create_generator(args): 37 | generator = FullGenerator(args.size, 38 | args.latent, 39 | args.n_mlp, 40 | channel_multiplier=args.channel_multiplier, 41 | narrow=args.narrow, 42 | device=args.device, 43 | mask_p=args.mask_p, 44 | mask_n_noise=args.mask_n_noise, 45 | face_z=getattr(args, 'face_z', False), 46 | noise_mask_p=getattr(args, 'noise_mask_p', None)).to(args.device) 47 | return generator 48 | 49 | 50 | def read_wav(wav_path, fps=25, mel_step_size=16): 51 | temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav") 52 | if not wav_path.endswith('.wav'): 53 | print('Extracting raw audio...') 54 | audio_name = temp_audio_file.name 55 | command = 'ffmpeg -i %s -loglevel error -y -f wav -acodec pcm_s16le -ar 16000 %s' % (wav_path, audio_name) 56 | subprocess.call(command, shell=True) 57 | wav_path = audio_name 58 | 59 | wav = audio.load_wav(wav_path, 16000) 60 | mel = audio.melspectrogram(wav) 61 | if np.isnan(mel.reshape(-1)).sum() > 0: 62 | raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') 63 | mel_chunks = [] 64 | mel_idx_multiplier = 80. 
/ fps 65 | i = 0 66 | while 1: 67 | start_idx = int(i * mel_idx_multiplier) 68 | if start_idx + mel_step_size > len(mel[0]): 69 | break 70 | mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size]) 71 | i += 1 72 | temp_audio_file.close() 73 | return mel_chunks, mel 74 | 75 | 76 | def transformation_from_points(points1, points0, smooth=True, p_bias=None): 77 | points2 = np.array(points0) 78 | points2 = points2.astype(np.float64) 79 | points1 = points1.astype(np.float64) 80 | c1 = np.mean(points1, axis=0) 81 | c2 = np.mean(points2, axis=0) 82 | points1 -= c1 83 | points2 -= c2 84 | s1 = np.std(points1) 85 | s2 = np.std(points2) 86 | points1 /= s1 87 | points2 /= s2 88 | U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2)) 89 | R = (np.matmul(U, Vt)).T 90 | sR = (s2 / s1) * R 91 | T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1)) 92 | M = np.concatenate((sR, T), axis=1) 93 | if smooth: 94 | bias = points2[2] - points1[2] 95 | if p_bias is None: 96 | p_bias = bias 97 | else: 98 | bias = p_bias * 0.2 + bias * 0.8 99 | p_bias = bias 100 | M[:, 2] = M[:, 2] + bias 101 | return M, p_bias 102 | 103 | 104 | class AlignRestore(object): 105 | 106 | def __init__(self, align_points=3): 107 | if align_points == 3: 108 | self.upscale_factor = 1 109 | self.crop_ratio = (2.8, 2.8) 110 | self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]]) 111 | self.face_template = self.face_template * 2.8 112 | self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1])) 113 | self.p_bias = None 114 | 115 | def process(self, img, lmk_align=None, smooth=True, align_points=3): 116 | aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth) 117 | restored_img = self.restore_img(img, aligned_face, affine_matrix) 118 | cv2.imwrite("restored.jpg", restored_img) 119 | cv2.imwrite("aligned.jpg", aligned_face) 120 | return aligned_face, restored_img 121 | 122 | def align_warp_face(self, img, lmks3, smooth=True, 
border_mode='constant'): 123 | affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias) 124 | if border_mode == 'constant': 125 | border_mode = cv2.BORDER_CONSTANT 126 | elif border_mode == 'reflect101': 127 | border_mode = cv2.BORDER_REFLECT101 128 | elif border_mode == 'reflect': 129 | border_mode = cv2.BORDER_REFLECT 130 | cropped_face = cv2.warpAffine(img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]) 131 | return cropped_face, affine_matrix 132 | 133 | def align_warp_face2(self, img, landmark, border_mode='constant'): 134 | affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0] 135 | if border_mode == 'constant': 136 | border_mode = cv2.BORDER_CONSTANT 137 | elif border_mode == 'reflect101': 138 | border_mode = cv2.BORDER_REFLECT101 139 | elif border_mode == 'reflect': 140 | border_mode = cv2.BORDER_REFLECT 141 | cropped_face = cv2.warpAffine(img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) 142 | return cropped_face, affine_matrix 143 | 144 | def restore_img(self, input_img, face, affine_matrix): 145 | h, w, _, = input_img.shape 146 | h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) 147 | upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) 148 | inverse_affine = cv2.invertAffineTransform(affine_matrix) 149 | inverse_affine *= self.upscale_factor 150 | if self.upscale_factor > 1: 151 | extra_offset = 0.5 * self.upscale_factor 152 | else: 153 | extra_offset = 0 154 | inverse_affine[:, 2] += extra_offset 155 | inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up)) 156 | mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32) 157 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) 158 | inv_mask_erosion = cv2.erode(inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) 159 | pasted_face = 
inv_mask_erosion[:, :, None] * inv_restored 160 | total_face_area = np.sum(inv_mask_erosion) 161 | w_edge = int(total_face_area**0.5) // 20 162 | erosion_radius = w_edge * 2 163 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) 164 | blur_size = w_edge * 2 165 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) 166 | inv_soft_mask = inv_soft_mask[:, :, None] 167 | upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img 168 | if np.max(upsample_img) > 256: 169 | upsample_img = upsample_img.astype(np.uint16) 170 | else: 171 | upsample_img = upsample_img.astype(np.uint8) 172 | return upsample_img 173 | 174 | 175 | class laplacianSmooth(object): 176 | 177 | def __init__(self, smoothAlpha=0.3): 178 | self.smoothAlpha = smoothAlpha 179 | self.pts_last = None 180 | 181 | def smooth(self, pts_cur): 182 | if self.pts_last is None: 183 | self.pts_last = pts_cur.copy() 184 | return pts_cur.copy() 185 | x1 = min(pts_cur[:, 0]) 186 | x2 = max(pts_cur[:, 0]) 187 | y1 = min(pts_cur[:, 1]) 188 | y2 = max(pts_cur[:, 1]) 189 | width = x2 - x1 190 | pts_update = [] 191 | for i in range(len(pts_cur)): 192 | x_new, y_new = pts_cur[i] 193 | x_old, y_old = self.pts_last[i] 194 | tmp = (x_new - x_old)**2 + (y_new - y_old)**2 195 | w = np.exp(-tmp / (width * self.smoothAlpha)) 196 | x = x_old * w + x_new * (1 - w) 197 | y = y_old * w + y_new * (1 - w) 198 | pts_update.append([x, y]) 199 | pts_update = np.array(pts_update) 200 | self.pts_last = pts_update.copy() 201 | 202 | return pts_update --------------------------------------------------------------------------------