├── .gitignore
├── README.md
├── __init_paths.py
├── audio.py
├── hparams.py
├── inference.py
├── misc
│   ├── StyleSync.png
│   └── StyleSync0.png
├── op
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── requirements.txt
├── stylesync_model.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | ckpts/*
3 | results/*
4 | _README.md
5 | demo/
6 | mask*.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### This is a PyTorch implementation. The Paddle-based version is [here](https://github.com/guanjz20/StyleSync).
2 |
3 | # StyleSync: High-Fidelity Generalized and Personalized Lip Sync in Style-based Generator (CVPR 2023)
4 |
5 | Jiazhi Guan*, Zhanwang Zhang*, [Hang Zhou](https://hangz-nju-cuhk.github.io/)†, [Tianshu Hu](https://scholar.google.com/citations?user=BIixVT0AAAAJ)†, [Kaisiyuan Wang](https://scholar.google.com/citations?user=2Pedf3EAAAAJ), [Dongliang He](https://scholar.google.com/citations?user=ui6DYGoAAAAJ), Haocheng Feng, [Jingtuo Liu](https://scholar.google.com/citations?user=tVV3jmcAAAAJ), [Errui Ding](https://scholar.google.com/citations?user=1wzEtxcAAAAJ), [Ziwei Liu](https://liuziwei7.github.io/), [Jingdong Wang](https://jingdongwang2017.github.io/)
6 |
7 |
8 |
9 | ### [Project](https://hangz-nju-cuhk.github.io/projects/StyleSync) | [Paper](https://arxiv.org/pdf/2305.05445.pdf) | [Video](https://www.youtube.com/watch?v=uuBglL2KGFc) | [Demo](https://www.youtube.com/watch?v=yAPDl2dVonY)
10 |
11 | We propose **StyleSync**, an effective framework that enables high-fidelity lip synchronization. We find that a style-based generator is sufficient to enable this appealing property in both one-shot and few-shot scenarios.
12 |
13 | ## Code
14 | The inference script and model code have been released.
15 |
16 | ## Run Generation
17 | Our team has decided to suspend the release of the model weights for various considerations, and there is currently no expected release date. However, if you would like to use our demo for academic purposes, such as running a comparison in your paper, please don't hesitate to contact me at [guanjz20 at mails dot tsinghua dot edu dot cn].
18 |
19 |
20 | ## Citation
21 | ```
22 | @inproceedings{guan2023stylesync,
23 | title = {StyleSync: High-Fidelity Generalized and Personalized Lip Sync in Style-based Generator},
24 | author = {Guan, Jiazhi and Zhang, Zhanwang and Zhou, Hang and HU, Tianshu and Wang, Kaisiyuan and He, Dongliang and Feng, Haocheng and Liu, Jingtuo and Ding, Errui and Liu, Ziwei and Wang, Jingdong},
25 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
26 | year = {2023}
27 | }
28 | ```
--------------------------------------------------------------------------------
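For reference, a minimal invocation of the released inference script, using the arguments defined in `inference.py`; the checkpoint and input paths below are placeholders rather than files shipped with the repo, and `--mask` defaults to `mask3.jpg`, which is excluded by `.gitignore`:

```
python inference.py \
    --checkpoint_path ckpts/stylesync.pth \
    --face demo/input_video.mp4 \
    --audio demo/driving_audio.wav \
    --save_root ./results \
    --img_size 256
```

`--face` also accepts a single image or a directory of numbered `.jpg` frames; when `--audio` points to a video file, its audio track is first extracted to a 16 kHz wav with ffmpeg.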
/__init_paths.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 |
4 | def add_path(path):
5 | if path not in sys.path:
6 | sys.path.insert(0, path)
7 |
8 | this_dir = osp.dirname(__file__)
9 |
10 | WORK_ROOT = '/home/guanjz'
11 | OUT_ROOT = './result'
12 | LOG_ROOT = './log'
13 | TEMP_ROOT = './temp'
14 | DATA_ROOT = ''
15 | VIDEO_ROOT = ''
--------------------------------------------------------------------------------
/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | from scipy import signal
4 | from scipy.io import wavfile
5 | from hparams import hparams as hp
6 |
7 |
8 | def load_wav(path, sr):
9 | return librosa.core.load(path, sr=sr)[0]
10 |
11 |
12 | def save_wav(wav, path, sr):
13 | wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14 | wavfile.write(path, sr, wav.astype(np.int16))
15 |
16 |
17 | def save_wavenet_wav(wav, path, sr):
18 |     import soundfile as sf  # librosa.output.write_wav was removed in librosa >= 0.8; soundfile ships with librosa
19 |     sf.write(path, wav, sr)
20 |
21 | def preemphasis(wav, k, preemphasize=True):
22 | if preemphasize:
23 | return signal.lfilter([1, -k], [1], wav)
24 | return wav
25 |
26 |
27 | def inv_preemphasis(wav, k, inv_preemphasize=True):
28 | if inv_preemphasize:
29 | return signal.lfilter([1], [1, -k], wav)
30 | return wav
31 |
32 |
33 | def get_hop_size():
34 | hop_size = hp.hop_size
35 | if hop_size is None:
36 | assert hp.frame_shift_ms is not None
37 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
38 | return hop_size
39 |
40 |
41 | def linearspectrogram(wav):
42 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
43 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db
44 |
45 | if hp.signal_normalization:
46 | return _normalize(S)
47 | return S
48 |
49 |
50 | def melspectrogram(wav):
51 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
52 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
53 |
54 | if hp.signal_normalization:
55 | return _normalize(S)
56 | return S
57 |
58 |
59 | def _lws_processor():
60 | import lws
61 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
62 |
63 |
64 | def _stft(y):
65 | if hp.use_lws:
66 |         return _lws_processor().stft(y).T
67 | else:
68 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
69 |
70 |
71 | def num_frames(length, fsize, fshift):
72 | """Compute number of time frames of spectrogram
73 | """
74 | pad = (fsize - fshift)
75 | if length % fshift == 0:
76 | M = (length + pad * 2 - fsize) // fshift + 1
77 | else:
78 | M = (length + pad * 2 - fsize) // fshift + 2
79 | return M
80 |
81 |
82 | def pad_lr(x, fsize, fshift):
83 | """Compute left and right padding
84 | """
85 | M = num_frames(len(x), fsize, fshift)
86 | pad = (fsize - fshift)
87 | T = len(x) + 2 * pad
88 | r = (M - 1) * fshift + fsize - T
89 | return pad, pad + r
90 |
91 |
92 | def librosa_pad_lr(x, fsize, fshift):
93 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
94 |
95 |
96 | _mel_basis = None
97 |
98 |
99 | def _linear_to_mel(spectogram):
100 | global _mel_basis
101 | if _mel_basis is None:
102 | _mel_basis = _build_mel_basis()
103 | return np.dot(_mel_basis, spectogram)
104 |
105 |
106 | def _build_mel_basis():
107 | assert hp.fmax <= hp.sample_rate // 2
108 |     return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin, fmax=hp.fmax)
109 |
110 |
111 | def _amp_to_db(x):
112 | min_level = np.exp(hp.min_level_db / 20 * np.log(10))
113 | return 20 * np.log10(np.maximum(min_level, x))
114 |
115 |
116 | def _db_to_amp(x):
117 | return np.power(10.0, (x) * 0.05)
118 |
119 |
120 | def _normalize(S):
121 | if hp.allow_clipping_in_normalization:
122 | if hp.symmetric_mels:
123 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, -hp.max_abs_value,
124 | hp.max_abs_value)
125 | else:
126 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
127 |
128 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
129 | if hp.symmetric_mels:
130 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
131 | else:
132 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
133 |
134 |
135 | def _denormalize(D):
136 | if hp.allow_clipping_in_normalization:
137 | if hp.symmetric_mels:
138 | return (((np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) +
139 | hp.min_level_db)
140 | else:
141 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
142 |
143 | if hp.symmetric_mels:
144 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
145 | else:
146 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
147 |
--------------------------------------------------------------------------------
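A minimal sketch of how these helpers are typically used, assuming the defaults in `hparams.py` (16 kHz audio, `n_fft=800`, `hop_size=200`, 80 mel bins); the wav path is a placeholder:

```python
import audio
from hparams import hparams as hp

wav = audio.load_wav('some_speech.wav', hp.sample_rate)  # placeholder path; mono float waveform
mel = audio.melspectrogram(wav)                          # shape (num_mels, T) = (80, T)

# with hop_size=200 at 16 kHz, each mel frame covers 12.5 ms, i.e. 80 frames per second of audio;
# values are normalized and clipped to [-max_abs_value, max_abs_value] = [-4, 4] by _normalize
print(mel.shape, mel.min(), mel.max())
```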
/hparams.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 |
5 | class HParams:
6 |
7 | def __init__(self, **kwargs):
8 | self.data = {}
9 |
10 | for key, value in kwargs.items():
11 | self.data[key] = value
12 |
13 | def __getattr__(self, key):
14 | if key not in self.data:
15 | raise AttributeError("'HParams' object has no attribute %s" % key)
16 | return self.data[key]
17 |
18 | def set_hparam(self, key, value):
19 | self.data[key] = value
20 |
21 |
22 | hparams = HParams(
23 | num_mels=80,
24 | rescale=True,
25 | rescaling_max=0.9,
26 | use_lws=False,
27 | n_fft=800,
28 | hop_size=200,
29 | win_size=800,
30 | sample_rate=16000,
31 | frame_shift_ms=None,
32 | signal_normalization=True,
33 | allow_clipping_in_normalization=True,
34 | symmetric_mels=True,
35 | max_abs_value=4.,
36 | preemphasize=True,
37 | preemphasis=0.97,
38 | min_level_db=-100,
39 | ref_level_db=20,
40 | fmin=55,
41 | fmax=7600,
42 | img_size=256,
43 | fps=25,
44 | initial_learning_rate=1e-4,
45 | nepochs=200000000000000000,
46 | num_workers=4,
47 | checkpoint_interval=10000,
48 | eval_interval=10000,
49 | save_optimizer_state=False,
50 | syncnet_lr=1e-4, #1e-4
51 | syncnet_eval_interval=4000,
52 | syncnet_checkpoint_interval=4000,
53 | disc_initial_learning_rate=1e-4,
54 | syncnet_batch_size=256,
55 | batch_size=16,
56 | )
57 |
58 |
59 | def hparams_debug_string():
60 |     values = hparams.data  # the hyperparameters are stored in a plain dict
61 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
62 | return "Hyperparameters:\n" + "\n".join(hp)
63 |
--------------------------------------------------------------------------------
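For reference, a small illustration of how the `HParams` container is used (the override value here is arbitrary):

```python
from hparams import hparams, HParams

print(hparams.num_mels)              # 80, resolved through __getattr__ from the internal dict
hparams.set_hparam('batch_size', 8)  # override an entry in place

custom = HParams(sample_rate=16000, num_mels=80)
print(custom.sample_rate)            # 16000; unknown keys raise AttributeError
```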
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import sys
4 | import random
5 | import argparse
6 | import torch
7 | import tempfile
8 | import subprocess
9 | import face_alignment
10 | import numpy as np
11 | import os.path as osp
12 | from tqdm import tqdm
13 | from glob import glob
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--checkpoint_path', type=str, required=True)
17 | parser.add_argument('--face', type=str, required=True)
18 | parser.add_argument('--audio', type=str, required=True)
19 | parser.add_argument('--max_mel_chunks', type=int, default=None)
20 | parser.add_argument('--save_root', type=str, default='./results')
21 | parser.add_argument('--save_name', type=str, default=None)
22 | parser.add_argument('--fps', type=float, default=25)
23 | parser.add_argument('--img_size', type=int, default=256)
24 | parser.add_argument('--mask', default="mask3.jpg", type=str)
25 | # run
26 | parser.add_argument('--gpus', type=str, default='0')
27 | parser.add_argument('--device', type=str, default='cuda')
28 | parser.add_argument('--save_crop', action='store_true')
29 | parser.add_argument('--tag', type=str, default='StyleSync Inference...')
30 |
31 | args = parser.parse_args()
32 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
33 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
34 |
35 | import __init_paths
36 | import audio
37 | import utils
38 |
39 |
40 | def main():
41 | # face
42 | temp_face_file = tempfile.NamedTemporaryFile(suffix=".mp4")
43 | if not os.path.isfile(args.face):
44 | fnames = list(glob(os.path.join(args.face, '*.jpg')))
45 | sorted_fnames = sorted(fnames, key=lambda f: int(os.path.basename(f).split('.')[0]))
46 | full_frames = [cv2.imread(f) for f in sorted_fnames]
47 | fps = args.fps
48 | elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
49 | full_frames = [cv2.imread(args.face)]
50 | fps = args.fps
51 | elif args.face.split('.')[-1] in ['mp4', 'mov', 'MOV', 'MP4', 'webm']:
52 | video_stream = cv2.VideoCapture(args.face)
53 | fps = video_stream.get(cv2.CAP_PROP_FPS)
54 |         if fps != args.fps:
55 |             print('Converting video to {} fps...'.format(args.fps))
56 |             video_name = temp_face_file.name
57 |             # resample the source video to the target frame rate before reading frames
58 |             command = 'ffmpeg -loglevel panic -i {} -qscale 0 -strict -2 -r {} -y {}'.format(args.face, args.fps, video_name)
59 |             subprocess.call(command, shell=True)
60 |             video_stream = cv2.VideoCapture(video_name)
61 |             fps = args.fps
62 | print('Reading video frames...')
63 | full_frames = []
64 | while 1:
65 | print('Reading {}...'.format(len(full_frames)), end='\r')
66 | still_reading, frame = video_stream.read()
67 | if not still_reading:
68 | video_stream.release()
69 | break
70 | full_frames.append(frame)
71 | if args.max_mel_chunks and len(full_frames) > args.max_mel_chunks + 10:
72 | video_stream.release()
73 | break
74 |
75 | # audio
76 | temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav")
77 |     if osp.basename(args.audio).split('.')[-1] in ['wav', 'mp3']:
78 |         wav_path = args.audio
79 |     elif os.path.basename(args.audio).split('.')[-1] in ['mp4', 'avi', 'MP4', 'AVI', 'MOV', 'mov', 'webm']:
80 | print('Extracting raw audio...')
81 | audio_name = temp_audio_file.name
82 | command = 'ffmpeg -i {} -loglevel error -y -f wav -acodec pcm_s16le -ar 16000 {}'.format(args.audio, audio_name)
83 | subprocess.call(command, shell=True)
84 | wav_path = audio_name
85 |
86 | # run
87 | with torch.no_grad():
88 | print("Loading model...")
89 | model = utils.load_model(args.checkpoint_path)
90 | model = model.to(args.device)
91 | model.eval()
92 |
93 | save_name = args.save_name or '{}_{}.mp4'.format(osp.basename(args.face).split('.')[0], osp.basename(args.audio).split('.')[0])
94 | save_path = os.path.join(args.save_root, save_name)
95 | os.makedirs(osp.dirname(save_path), exist_ok=True)
96 | print("=====>", save_path)
97 | infer_one(model, full_frames, wav_path, save_path)
98 |
99 | temp_face_file.close()
100 | temp_audio_file.close()
101 |
102 |
103 | def infer_one(model, imgs, wav_path, save_path):
104 | out_video_p = tempfile.NamedTemporaryFile(suffix=".avi")
105 | crop_out_video_p = tempfile.NamedTemporaryFile(suffix=".avi")
106 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=args.device)
107 | restorer = utils.AlignRestore()
108 | lmk_smoother = utils.laplacianSmooth()
109 | mel_chunks, _ = utils.read_wav(wav_path)
110 | mel_chunks = mel_chunks[:args.max_mel_chunks]
111 | img_mask = 1. - cv2.resize(cv2.imread(args.mask), (args.img_size, args.img_size), interpolation=cv2.INTER_AREA) / 255.
112 | img_idxs_org = list(range(len(imgs)))
113 | img_idxs_dst = []
114 | while (len(img_idxs_dst) < len(mel_chunks) + 10):
115 | img_idxs_dst += img_idxs_org
116 | img_idxs_org = img_idxs_org[::-1]
117 | frame_h, frame_w, _ = imgs[0].shape
118 | out = cv2.VideoWriter(out_video_p.name, cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))
119 |
120 | # img prep
121 | skip_begin = 0
122 | face_all = []
123 | box_all = []
124 | affine_matrix_all = []
125 | face_crop_data_dict = {}
126 | print('Run face cropping...')
127 | for i, m in tqdm(enumerate(mel_chunks), total=len(mel_chunks)):
128 | img_idx = img_idxs_dst[i]
129 | if img_idx in face_crop_data_dict:
130 | _data = face_crop_data_dict[img_idx]
131 | affine_matrix_all.append(_data['affine_matrix'])
132 | box_all.append(_data['box'])
133 | face_all.append(_data['face'])
134 | continue
135 | img = imgs[img_idx].copy()
136 | try:
137 | re_lmks = fa.get_landmarks(img.copy())
138 | points = lmk_smoother.smooth(re_lmks[0])
139 | lmk3_ = np.zeros((3, 2))
140 | lmk3_[0] = points[17:22].mean(0)
141 | lmk3_[1] = points[22:27].mean(0)
142 | lmk3_[2] = points[27:36].mean(0)
143 | except Exception as e:
144 | print('Face detection fail...\n[{}]'.format(e))
145 | if len(affine_matrix_all) == 0:
146 | skip_begin += 1
147 | continue
148 | affine_matrix_all.append(affine_matrix_all[-1])
149 | box_all.append(box_all[-1])
150 | face_all.append(face_all[-1])
151 | continue
152 | face, affine_matrix = restorer.align_warp_face(img.copy(), lmks3=lmk3_, smooth=True)
153 | box = [0, 0, face.shape[1], face.shape[0]]
154 | if i == 0 and args.save_crop:
155 | out_crop = cv2.VideoWriter(crop_out_video_p.name, cv2.VideoWriter_fourcc(*'DIVX'), 25, (face.shape[1], face.shape[0]))
156 | face = cv2.resize(face, (args.img_size, args.img_size), interpolation=cv2.INTER_CUBIC)
157 | affine_matrix_all.append(affine_matrix)
158 | box_all.append(box)
159 | face_all.append(face)
160 | face_crop_data_dict[img_idx] = {'affine_matrix': affine_matrix, 'box': box, 'face': face}
161 | while len(face_all) < len(mel_chunks):
162 | assert skip_begin > 0
163 | affine_matrix_all += affine_matrix_all[::-1]
164 | box_all += box_all[::-1]
165 | face_all += face_all[::-1]
166 |
167 | print('Run generation...')
168 | for i, m in tqdm(enumerate(mel_chunks), total=len(mel_chunks)):
169 | img = imgs[img_idxs_dst[i]].copy()
170 | face = face_all[i].copy()
171 | box = box_all[i]
172 | affine_matrix = affine_matrix_all[i]
173 | face_masked = face.copy() * img_mask
174 | ref_face = face.copy()
175 |
176 | # infer
177 | ref_faces = ref_face[np.newaxis, :]
178 | face_masked = face_masked[np.newaxis, :]
179 | mel_batch = mel_chunks[i][np.newaxis, :]
180 | img_batch = np.concatenate((face_masked, ref_faces), axis=3) / 255.
181 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
182 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(args.device)
183 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(args.device)
184 | pred = model(img_batch, mel_batch)
185 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
186 | pred = pred.astype(np.uint8)[0]
187 |
188 | # save
189 | x1, y1, x2, y2 = box
190 | pred = cv2.resize(pred, (x2 - x1, y2 - y1), interpolation=cv2.INTER_CUBIC)
191 | if args.save_crop:
192 | out_crop.write(pred)
193 | out_img = restorer.restore_img(img, pred, affine_matrix)
194 | out.write(out_img)
195 |
196 | # write video
197 | out.release()
198 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -vcodec libx264 -crf 12 -pix_fmt yuv420p -shortest {}'.format(
199 | wav_path, out_video_p.name, save_path)
200 | subprocess.call(command, shell=True)
201 | if args.save_crop:
202 | out_crop.release()
203 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -vcodec libx264 -crf 12 -pix_fmt yuv420p -shortest {}'.format(
204 | wav_path, crop_out_video_p.name, save_path[:-4] + "_crop.mp4")
205 | subprocess.call(command, shell=True)
206 |
207 | out_video_p.close()
208 | crop_out_video_p.close()
209 | print('[DONE]')
210 |
211 |
212 | if __name__ == '__main__':
213 | main()
214 |
--------------------------------------------------------------------------------
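In `infer_one`, source frames are reused in a ping-pong order so that a frame exists for every mel chunk even when the driving audio is longer than the input clip. A standalone sketch of that indexing, with made-up frame and chunk counts:

```python
# forward-backward cycling of frame indices, as built in infer_one
num_frames, num_mel_chunks = 4, 9
img_idxs_org = list(range(num_frames))            # [0, 1, 2, 3]
img_idxs_dst = []
while len(img_idxs_dst) < num_mel_chunks + 10:    # same stopping rule as the script
    img_idxs_dst += img_idxs_org
    img_idxs_org = img_idxs_org[::-1]             # reverse direction each pass

print(img_idxs_dst[:12])  # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1, 2, 3]
```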
/misc/StyleSync.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjz20/StyleSync_PyTorch/ecafd9016a204bf39a4137268f432ef778ff93d6/misc/StyleSync.png
--------------------------------------------------------------------------------
/misc/StyleSync0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjz20/StyleSync_PyTorch/ecafd9016a204bf39a4137268f432ef778ff93d6/misc/StyleSync0.png
--------------------------------------------------------------------------------
/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 | from torch.utils.cpp_extension import load, _import_module_from_library
9 |
10 | # this CUDA extension (adapted from GPEN) is built only on Linux with CUDA available; otherwise the CPU fallback in fused_leaky_relu is used
11 | if platform.system() == 'Linux' and torch.cuda.is_available():
12 | module_path = os.path.dirname(__file__)
13 | fused = load(
14 | 'fused',
15 | sources=[
16 | os.path.join(module_path, 'fused_bias_act.cpp'),
17 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
18 | ],
19 | )
20 |
21 |
22 | class FusedLeakyReLUFunctionBackward(Function):
23 |
24 | @staticmethod
25 | def forward(ctx, grad_output, out, negative_slope, scale):
26 | ctx.save_for_backward(out)
27 | ctx.negative_slope = negative_slope
28 | ctx.scale = scale
29 |
30 | empty = grad_output.new_empty(0)
31 |
32 | grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
33 |
34 | dim = [0]
35 |
36 | if grad_input.ndim > 2:
37 | dim += list(range(2, grad_input.ndim))
38 |
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | return grad_input, grad_bias
42 |
43 | @staticmethod
44 | def backward(ctx, gradgrad_input, gradgrad_bias):
45 | out, = ctx.saved_tensors
46 | gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale)
47 |
48 | return gradgrad_out, None, None, None
49 |
50 |
51 | class FusedLeakyReLUFunction(Function):
52 |
53 | @staticmethod
54 | def forward(ctx, input, bias, negative_slope, scale):
55 | empty = input.new_empty(0)
56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57 | ctx.save_for_backward(out)
58 | ctx.negative_slope = negative_slope
59 | ctx.scale = scale
60 |
61 | return out
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | out, = ctx.saved_tensors
66 |
67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
68 |
69 | return grad_input, grad_bias, None, None
70 |
71 |
72 | class FusedLeakyReLU(nn.Module):
73 |
74 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5, device='cpu'):
75 | super().__init__()
76 |
77 | self.bias = nn.Parameter(torch.zeros(channel))
78 | self.negative_slope = negative_slope
79 | self.scale = scale
80 | self.device = device
81 |
82 | def forward(self, input):
83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale, self.device)
84 |
85 |
86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5, device='cpu'):
87 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
89 | else:
90 | return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope)
91 |
--------------------------------------------------------------------------------
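When the CUDA extension is not available (non-Linux, no GPU, or `device='cpu'`), `fused_leaky_relu` falls back to the plain PyTorch expression at the bottom of this file. A small sketch of that fallback path, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)   # (N, C, H, W) feature map
bias = torch.zeros(8)          # one bias value per channel, as in FusedLeakyReLU
scale = 2 ** 0.5

# equivalent to fused_leaky_relu(x, bias, device='cpu'):
# broadcast the per-channel bias, apply LeakyReLU(0.2), then rescale by sqrt(2)
out = scale * F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2)
print(out.shape)               # torch.Size([2, 8, 4, 4])
```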
/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load, _import_module_from_library
8 |
9 | # this CUDA extension (adapted from GPEN) is built only on Linux with CUDA available; otherwise upfirdn2d_native is used
10 | if platform.system() == 'Linux' and torch.cuda.is_available():
11 | module_path = os.path.dirname(__file__)
12 | upfirdn2d_op = load(
13 | 'upfirdn2d',
14 | sources=[
15 | os.path.join(module_path, 'upfirdn2d.cpp'),
16 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
17 | ],
18 | )
19 |
20 |
21 | class UpFirDn2dBackward(Function):
22 |
23 | @staticmethod
24 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
82 |
83 | return gradgrad_out, None, None, None, None, None, None, None, None
84 |
85 |
86 | class UpFirDn2d(Function):
87 |
88 | @staticmethod
89 | def forward(ctx, input, kernel, up, down, pad):
90 | up_x, up_y = up
91 | down_x, down_y = down
92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
93 |
94 | kernel_h, kernel_w = kernel.shape
95 | batch, channel, in_h, in_w = input.shape
96 | ctx.in_size = input.shape
97 |
98 | input = input.reshape(-1, in_h, in_w, 1)
99 |
100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101 |
102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104 | ctx.out_size = (out_h, out_w)
105 |
106 | ctx.up = (up_x, up_y)
107 | ctx.down = (down_x, down_y)
108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109 |
110 | g_pad_x0 = kernel_w - pad_x0 - 1
111 | g_pad_y0 = kernel_h - pad_y0 - 1
112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114 |
115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116 |
117 | out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
118 | out = out.view(-1, channel, out_h, out_w)
119 |
120 | return out
121 |
122 | @staticmethod
123 | def backward(ctx, grad_output):
124 | kernel, grad_kernel = ctx.saved_tensors
125 |
126 | grad_input = UpFirDn2dBackward.apply(
127 | grad_output,
128 | kernel,
129 | grad_kernel,
130 | ctx.up,
131 | ctx.down,
132 | ctx.pad,
133 | ctx.g_pad,
134 | ctx.in_size,
135 | ctx.out_size,
136 | )
137 |
138 | return grad_input, None, None, None, None
139 |
140 |
141 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), device='cpu'):
142 | if platform.system() == 'Linux' and torch.cuda.is_available() and device != 'cpu':
143 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
144 | else:
145 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
146 |
147 | return out
148 |
149 |
150 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
151 | input = input.permute(0, 2, 3, 1)
152 | _, in_h, in_w, minor = input.shape
153 | kernel_h, kernel_w = kernel.shape
154 | out = input.view(-1, in_h, 1, in_w, 1, minor)
155 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
156 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
157 |
158 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
159 | out = out[
160 | :,
161 | max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
162 | max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0),
163 | :,
164 | ]
165 |
166 | out = out.permute(0, 3, 1, 2)
167 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
168 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
169 | out = F.conv2d(out, w)
170 | out = out.reshape(
171 | -1,
172 | minor,
173 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
174 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
175 | )
176 | return out[:, :, ::down_y, ::down_x]
177 |
--------------------------------------------------------------------------------
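`upfirdn2d` fuses upsample, FIR filtering, and downsample into one op; with `device='cpu'` it dispatches to `upfirdn2d_native`. A minimal sketch of a 2x blur-upsample, mirroring how `Upsample` in `stylesync_model.py` builds its kernel and padding (input sizes are illustrative):

```python
import torch
from op import upfirdn2d

# separable [1, 3, 3, 1] blur kernel: outer product, normalized (as in make_kernel),
# then scaled by factor**2 (as in Upsample)
k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum() * 4            # factor = 2, so scale by factor**2

factor = 2
p = kernel.shape[0] - factor                  # 2
pad = ((p + 1) // 2 + factor - 1, p // 2)     # (2, 1)

x = torch.randn(1, 3, 8, 8)
y = upfirdn2d(x, kernel, up=factor, down=1, pad=pad, device='cpu')
print(y.shape)                                # torch.Size([1, 3, 16, 16])
```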
/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19 | int c = a / b;
20 |
21 | if (c * b > a) {
22 | c--;
23 | }
24 |
25 | return c;
26 | }
27 |
28 |
29 | struct UpFirDn2DKernelParams {
30 | int up_x;
31 | int up_y;
32 | int down_x;
33 | int down_y;
34 | int pad_x0;
35 | int pad_x1;
36 | int pad_y0;
37 | int pad_y1;
38 |
39 | int major_dim;
40 | int in_h;
41 | int in_w;
42 | int minor_dim;
43 | int kernel_h;
44 | int kernel_w;
45 | int out_h;
46 | int out_w;
47 | int loop_major;
48 | int loop_x;
49 | };
50 |
51 |
52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56 |
57 | __shared__ volatile float sk[kernel_h][kernel_w];
58 | __shared__ volatile float sx[tile_in_h][tile_in_w];
59 |
60 | int minor_idx = blockIdx.x;
61 | int tile_out_y = minor_idx / p.minor_dim;
62 | minor_idx -= tile_out_y * p.minor_dim;
63 | tile_out_y *= tile_out_h;
64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65 | int major_idx_base = blockIdx.z * p.loop_major;
66 |
67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68 | return;
69 | }
70 |
71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72 | int ky = tap_idx / kernel_w;
73 | int kx = tap_idx - ky * kernel_w;
74 | scalar_t v = 0.0;
75 |
76 | if (kx < p.kernel_w & ky < p.kernel_h) {
77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78 | }
79 |
80 | sk[ky][kx] = v;
81 | }
82 |
83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87 | int tile_in_x = floor_div(tile_mid_x, up_x);
88 | int tile_in_y = floor_div(tile_mid_y, up_y);
89 |
90 | __syncthreads();
91 |
92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93 | int rel_in_y = in_idx / tile_in_w;
94 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
95 | int in_x = rel_in_x + tile_in_x;
96 | int in_y = rel_in_y + tile_in_y;
97 |
98 | scalar_t v = 0.0;
99 |
100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102 | }
103 |
104 | sx[rel_in_y][rel_in_x] = v;
105 | }
106 |
107 | __syncthreads();
108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109 | int rel_out_y = out_idx / tile_out_w;
110 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
111 | int out_x = rel_out_x + tile_out_x;
112 | int out_y = rel_out_y + tile_out_y;
113 |
114 | int mid_x = tile_mid_x + rel_out_x * down_x;
115 | int mid_y = tile_mid_y + rel_out_y * down_y;
116 | int in_x = floor_div(mid_x, up_x);
117 | int in_y = floor_div(mid_y, up_y);
118 | int rel_in_x = in_x - tile_in_x;
119 | int rel_in_y = in_y - tile_in_y;
120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122 |
123 | scalar_t v = 0.0;
124 |
125 | #pragma unroll
126 | for (int y = 0; y < kernel_h / up_y; y++)
127 | #pragma unroll
128 | for (int x = 0; x < kernel_w / up_x; x++)
129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130 |
131 | if (out_x < p.out_w & out_y < p.out_h) {
132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133 | }
134 | }
135 | }
136 | }
137 | }
138 |
139 |
140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141 | int up_x, int up_y, int down_x, int down_y,
142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143 | int curDevice = -1;
144 | cudaGetDevice(&curDevice);
145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146 |
147 | UpFirDn2DKernelParams p;
148 |
149 | auto x = input.contiguous();
150 | auto k = kernel.contiguous();
151 |
152 | p.major_dim = x.size(0);
153 | p.in_h = x.size(1);
154 | p.in_w = x.size(2);
155 | p.minor_dim = x.size(3);
156 | p.kernel_h = k.size(0);
157 | p.kernel_w = k.size(1);
158 | p.up_x = up_x;
159 | p.up_y = up_y;
160 | p.down_x = down_x;
161 | p.down_y = down_y;
162 | p.pad_x0 = pad_x0;
163 | p.pad_x1 = pad_x1;
164 | p.pad_y0 = pad_y0;
165 | p.pad_y1 = pad_y1;
166 |
167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169 |
170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171 |
172 | int mode = -1;
173 |
174 | int tile_out_h;
175 | int tile_out_w;
176 |
177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178 | mode = 1;
179 | tile_out_h = 16;
180 | tile_out_w = 64;
181 | }
182 |
183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184 | mode = 2;
185 | tile_out_h = 16;
186 | tile_out_w = 64;
187 | }
188 |
189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190 | mode = 3;
191 | tile_out_h = 16;
192 | tile_out_w = 64;
193 | }
194 |
195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196 | mode = 4;
197 | tile_out_h = 16;
198 | tile_out_w = 64;
199 | }
200 |
201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202 | mode = 5;
203 | tile_out_h = 8;
204 | tile_out_w = 32;
205 | }
206 |
207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208 | mode = 6;
209 | tile_out_h = 8;
210 | tile_out_w = 32;
211 | }
212 |
213 | dim3 block_size;
214 | dim3 grid_size;
215 |
216 |     if (tile_out_h > 0 && tile_out_w > 0) {
217 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
218 | p.loop_x = 1;
219 | block_size = dim3(32 * 8, 1, 1);
220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222 | (p.major_dim - 1) / p.loop_major + 1);
223 | }
224 |
225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226 | switch (mode) {
227 | case 1:
228 |             upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230 | );
231 |
232 | break;
233 |
234 | case 2:
235 |             upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237 | );
238 |
239 | break;
240 |
241 | case 3:
242 |             upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244 | );
245 |
246 | break;
247 |
248 | case 4:
249 |             upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251 | );
252 |
253 | break;
254 |
255 | case 5:
256 |             upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258 | );
259 |
260 | break;
261 |
262 | case 6:
263 |             upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
264 |                 out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265 | );
266 |
267 | break;
268 | }
269 | });
270 |
271 | return out;
272 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.9.2
2 | matplotlib==3.5.1
3 | munch==2.5.0
4 | ninja==1.10.2.4
5 | opencv-python==4.6.0.66
6 | opencv-python-headless==4.6.0.66
7 | scikit-image==0.19.2
8 | torch==1.12.1
9 | torchvision==0.13.1
10 | tqdm==4.64.0
11 | timm==0.6.13
12 | face-alignment==1.3.5
--------------------------------------------------------------------------------
/stylesync_model.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import re
3 | import math
4 | import random
5 | import itertools
6 | import logging
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | import utils
12 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
13 |
14 | _logger = logging.getLogger('model')
15 |
16 |
17 | def make_kernel(k):
18 | k = torch.tensor(k, dtype=torch.float32)
19 |
20 | if k.ndim == 1:
21 | k = k[None, :] * k[:, None]
22 |
23 | k /= k.sum()
24 |
25 | return k
26 |
27 |
28 | class PixelNorm(nn.Module):
29 |
30 | def __init__(self):
31 | super().__init__()
32 |
33 | def forward(self, input):
34 | return input * torch.rsqrt(torch.mean(input**2, dim=1, keepdim=True) + 1e-8)
35 |
36 |
37 | class Upsample(nn.Module):
38 |
39 | def __init__(self, kernel, factor=2, device='cpu'):
40 | super().__init__()
41 |
42 | self.factor = factor
43 | kernel = make_kernel(kernel) * (factor**2)
44 | self.register_buffer('kernel', kernel)
45 |
46 | p = kernel.shape[0] - factor
47 |
48 | pad0 = (p + 1) // 2 + factor - 1
49 | pad1 = p // 2
50 |
51 | self.pad = (pad0, pad1)
52 | self.device = device
53 |
54 | def forward(self, input):
55 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad, device=self.device)
56 |
57 | return out
58 |
59 |
60 | class Downsample(nn.Module):
61 |
62 | def __init__(self, kernel, factor=2, device='cpu'):
63 | super().__init__()
64 |
65 | self.factor = factor
66 | kernel = make_kernel(kernel)
67 | self.register_buffer('kernel', kernel)
68 |
69 | p = kernel.shape[0] - factor
70 |
71 | pad0 = (p + 1) // 2
72 | pad1 = p // 2
73 |
74 | self.pad = (pad0, pad1)
75 | self.device = device
76 |
77 | def forward(self, input):
78 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad, device=self.device)
79 |
80 | return out
81 |
82 |
83 | class Blur(nn.Module):
84 |
85 | def __init__(self, kernel, pad, upsample_factor=1, device='cpu'):
86 | super().__init__()
87 |
88 | kernel = make_kernel(kernel)
89 |
90 | if upsample_factor > 1:
91 | kernel = kernel * (upsample_factor**2)
92 |
93 | self.register_buffer('kernel', kernel)
94 |
95 | self.pad = pad
96 | self.device = device
97 |
98 | def forward(self, input):
99 | out = upfirdn2d(input, self.kernel, pad=self.pad, device=self.device)
100 |
101 | return out
102 |
103 |
104 | class EqualConv2d(nn.Module):
105 |
106 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
107 | super().__init__()
108 |
109 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
110 | self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
111 |
112 | self.stride = stride
113 | self.padding = padding
114 |
115 | if bias:
116 | self.bias = nn.Parameter(torch.zeros(out_channel))
117 | else:
118 | self.bias = None
119 |
120 | def forward(self, input):
121 | out = F.conv2d(
122 | input,
123 | self.weight * self.scale,
124 | bias=self.bias,
125 | stride=self.stride,
126 | padding=self.padding,
127 | )
128 |
129 | return out
130 |
131 | def __repr__(self):
132 | return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
133 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})')
134 |
135 |
136 | class EqualLinear(nn.Module):
137 |
138 | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None, device='cpu'):
139 | super().__init__()
140 |
141 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
142 |
143 | if bias:
144 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
145 | else:
146 | self.bias = None
147 |
148 | self.activation = activation
149 | self.device = device
150 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
151 | self.lr_mul = lr_mul
152 |
153 | def forward(self, input):
154 | if self.activation:
155 | out = F.linear(input, self.weight * self.scale)
156 | out = fused_leaky_relu(out, self.bias * self.lr_mul, device=self.device)
157 | else:
158 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
159 |
160 | return out
161 |
162 | def __repr__(self):
163 | return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
164 |
165 |
166 | class ScaledLeakyReLU(nn.Module):
167 |
168 | def __init__(self, negative_slope=0.2):
169 | super().__init__()
170 | self.negative_slope = negative_slope
171 |
172 | def forward(self, input):
173 | out = F.leaky_relu(input, negative_slope=self.negative_slope)
174 |
175 | return out * math.sqrt(2)
176 |
177 |
178 | class ModulatedConv2d(nn.Module):
179 |
180 | def __init__(self,
181 | in_channel,
182 | out_channel,
183 | kernel_size,
184 | style_dim,
185 | demodulate=True,
186 | upsample=False,
187 | downsample=False,
188 | blur_kernel=[1, 3, 3, 1],
189 | device='cpu'):
190 | super().__init__()
191 |
192 | self.eps = 1e-8
193 | self.kernel_size = kernel_size
194 | self.in_channel = in_channel
195 | self.out_channel = out_channel
196 | self.upsample = upsample
197 | self.downsample = downsample
198 |
199 | if upsample:
200 | factor = 2
201 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
202 | pad0 = (p + 1) // 2 + factor - 1
203 | pad1 = p // 2 + 1
204 |
205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor, device=device)
206 |
207 | if downsample:
208 | factor = 2
209 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
210 | pad0 = (p + 1) // 2
211 | pad1 = p // 2
212 |
213 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), device=device)
214 |
215 | fan_in = in_channel * kernel_size**2
216 | self.scale = 1 / math.sqrt(fan_in)
217 | self.padding = kernel_size // 2
218 |
219 | self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))
220 |
221 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
222 |
223 | self.demodulate = demodulate
224 |
225 | def __repr__(self):
226 | return (f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
227 | f'upsample={self.upsample}, downsample={self.downsample})')
228 |
229 | def forward(self, input, style):
230 | batch, in_channel, height, width = input.shape
231 |
232 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
233 | weight = self.scale * self.weight * style
234 |
235 | if self.demodulate:
236 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
237 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
238 |
239 | weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)
240 |
241 | if self.upsample:
242 | input = input.view(1, batch * in_channel, height, width)
243 | weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size)
244 | weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size)
245 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
246 | _, _, height, width = out.shape
247 | out = out.view(batch, self.out_channel, height, width)
248 | out = self.blur(out)
249 |
250 | elif self.downsample:
251 | input = self.blur(input)
252 | _, _, height, width = input.shape
253 | input = input.view(1, batch * in_channel, height, width)
254 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
255 | _, _, height, width = out.shape
256 | out = out.view(batch, self.out_channel, height, width)
257 |
258 | else:
259 | input = input.view(1, batch * in_channel, height, width)
260 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
261 | _, _, height, width = out.shape
262 | out = out.view(batch, self.out_channel, height, width)
263 |
264 | return out
265 |
266 |
267 | class NoiseInjection(nn.Module):
268 |
269 | def __init__(self, isconcat=True):
270 | super().__init__()
271 |
272 | self.isconcat = isconcat
273 | self.weight = nn.Parameter(torch.zeros(1))
274 |
275 | def forward(self, image, noise=None):
276 | if noise is None:
277 | batch, channel, height, width = image.shape
278 | noise = image.new_empty(batch, channel, height, width).normal_()
279 |
280 | if self.isconcat:
281 | return torch.cat((image, self.weight * noise), dim=1)
282 | else:
283 | return image + self.weight * noise
284 |
285 |
286 | class ConstantInput(nn.Module):
287 |
288 | def __init__(self, channel, size=4):
289 | super().__init__()
290 |
291 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
292 |
293 | def forward(self, input):
294 | batch = input.shape[0]
295 | out = self.input.repeat(batch, 1, 1, 1)
296 |
297 | return out
298 |
299 |
300 | class StyledConv(nn.Module):
301 |
302 | def __init__(self,
303 | in_channel,
304 | out_channel,
305 | kernel_size,
306 | style_dim,
307 | upsample=False,
308 | blur_kernel=[1, 3, 3, 1],
309 | demodulate=True,
310 | isconcat=True,
311 | device='cpu'):
312 | super().__init__()
313 |
314 | self.conv = ModulatedConv2d(in_channel,
315 | out_channel,
316 | kernel_size,
317 | style_dim,
318 | upsample=upsample,
319 | blur_kernel=blur_kernel,
320 | demodulate=demodulate,
321 | device=device)
322 |
323 | self.noise = NoiseInjection(isconcat)
324 | feat_multiplier = 2 if isconcat else 1
325 | self.activate = FusedLeakyReLU(out_channel * feat_multiplier, device=device)
326 |
327 | def forward(self, input, style, noise=None):
328 | out = self.conv(input, style)
329 | out = self.noise(out, noise=noise)
330 | out = self.activate(out)
331 |
332 | return out
333 |
334 |
335 | class ToRGB(nn.Module):
336 |
337 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], device='cpu', out_channel=3):
338 | super().__init__()
339 |
340 | if upsample:
341 | self.upsample = Upsample(blur_kernel, device=device)
342 |
343 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False, device=device)
344 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
345 |
346 | def forward(self, input, style, skip=None):
347 | out = self.conv(input, style)
348 | out = out + self.bias
349 |
350 | if skip is not None:
351 | skip = self.upsample(skip)
352 | out = out + skip
353 |
354 | return out
355 |
356 |
357 | class ConvLayer(nn.Sequential):
358 |
359 | def __init__(self,
360 | in_channel,
361 | out_channel,
362 | kernel_size,
363 | downsample=False,
364 | blur_kernel=[1, 3, 3, 1],
365 | bias=True,
366 | activate=True,
367 | device='cpu'):
368 | layers = []
369 |
370 | if downsample:
371 | factor = 2
372 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
373 | pad0 = (p + 1) // 2
374 | pad1 = p // 2
375 |
376 | layers.append(Blur(blur_kernel, pad=(pad0, pad1), device=device))
377 |
378 | stride = 2
379 | self.padding = 0
380 |
381 | else:
382 | stride = 1
383 | self.padding = kernel_size // 2
384 |
385 | layers.append(EqualConv2d(
386 | in_channel,
387 | out_channel,
388 | kernel_size,
389 | padding=self.padding,
390 | stride=stride,
391 | bias=bias and not activate,
392 | ))
393 |
394 | if activate:
395 | if bias:
396 | layers.append(FusedLeakyReLU(out_channel, device=device))
397 | else:
398 | layers.append(ScaledLeakyReLU(0.2))
399 |
400 | super().__init__(*layers)
401 |
402 |
403 | class ResBlock(nn.Module):
404 |
405 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], device='cpu'):
406 | super().__init__()
407 |
408 | self.conv1 = ConvLayer(in_channel, in_channel, 3, device=device)
409 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)
410 |
411 | self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
412 |
413 | def forward(self, input):
414 | out = self.conv1(input)
415 | out = self.conv2(out)
416 |
417 | skip = self.skip(input)
418 | out = (out + skip) / math.sqrt(2)
419 |
420 | return out
421 |
422 |
423 | class audioConv2d(nn.Module):
424 |
425 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
426 | super().__init__(*args, **kwargs)
427 | num_groups = 32
428 | self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding),
429 | nn.GroupNorm(num_groups=num_groups, num_channels=cout))
430 | self.act = nn.LeakyReLU(0.01, inplace=True)
431 | self.residual = residual
432 |
433 | def forward(self, x):
434 | out = self.conv_block(x)
435 | if self.residual:
436 | out += x
437 | return self.act(out)
438 |
439 |
440 | class AudioEncoder(nn.Module):
441 |
442 | def __init__(self, lr_mlp=0.01, device='cpu'):
443 | super().__init__()
444 | self.encoder = nn.Sequential(
445 | audioConv2d(1, 32, kernel_size=3, stride=1, padding=1),
446 | audioConv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
447 | audioConv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
448 | audioConv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
449 | audioConv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
450 | audioConv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
451 | audioConv2d(64, 128, kernel_size=3, stride=3, padding=1),
452 | audioConv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
453 | audioConv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
454 | audioConv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
455 | audioConv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
456 | audioConv2d(256, 512, kernel_size=3, stride=1, padding=0),
457 | audioConv2d(512, 512, kernel_size=1, stride=1, padding=0),
458 | )
459 | self.linear = nn.Sequential(EqualLinear(512, 512, activation='fused_lrelu', device=device))
460 |
461 | def forward(self, x):
462 | x = self.encoder(x)
463 | x = x.view(x.shape[0], -1)
464 | x = self.linear(x)
465 | return x
466 |
467 |
468 | class Generator(nn.Module):
469 |
470 | def __init__(
471 | self,
472 | size,
473 | style_dim,
474 | n_mlp,
475 | channel_multiplier=2,
476 | blur_kernel=[1, 3, 3, 1],
477 | lr_mlp=0.01,
478 | isconcat=True,
479 | narrow=1,
480 | device='cpu',
481 | ):
482 | super().__init__()
483 |
484 | self.size = size
485 | self.n_mlp = n_mlp
486 | self.style_dim = style_dim
487 | self.feat_multiplier = 2 if isconcat else 1
488 |
489 | layers = [PixelNorm()]
490 |
491 | for i in range(n_mlp):
492 | layers.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu', device=device))
493 |
494 | self.style = nn.Sequential(*layers)
495 |
496 | self.channels = {
497 | 4: int(512 * narrow),
498 | 8: int(512 * narrow),
499 | 16: int(512 * narrow),
500 | 32: int(512 * narrow),
501 | 64: int(256 * channel_multiplier * narrow),
502 | 128: int(128 * channel_multiplier * narrow),
503 | 256: int(64 * channel_multiplier * narrow),
504 | 512: int(32 * channel_multiplier * narrow),
505 | 1024: int(16 * channel_multiplier * narrow),
506 | 2048: int(8 * channel_multiplier * narrow)
507 | }
508 |
509 | self.input = ConstantInput(self.channels[4])
510 | self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat, device=device)
511 | self.to_rgb1 = ToRGB(self.channels[4] * self.feat_multiplier, style_dim, upsample=False, device=device)
512 |
513 | self.log_size = int(math.log(size, 2))
514 | self.num_layers = (self.log_size - 2) * 2 + 1
515 |
516 | self.convs = nn.ModuleList()
517 | self.upsamples = nn.ModuleList()
518 | self.to_rgbs = nn.ModuleList()
519 |
520 | in_channel = self.channels[4]
521 |
522 | for i in range(3, self.log_size + 1):
523 | out_channel = self.channels[2**i]
524 |
525 | self.convs.append(
526 | StyledConv(in_channel * self.feat_multiplier,
527 | out_channel,
528 | 3,
529 | style_dim,
530 | upsample=True,
531 | blur_kernel=blur_kernel,
532 | isconcat=isconcat,
533 | device=device))
534 |
535 | self.convs.append(
536 | StyledConv(out_channel * self.feat_multiplier,
537 | out_channel,
538 | 3,
539 | style_dim,
540 | blur_kernel=blur_kernel,
541 | isconcat=isconcat,
542 | device=device))
543 |
544 | self.to_rgbs.append(ToRGB(out_channel * self.feat_multiplier, style_dim, device=device))
545 |
546 | in_channel = out_channel
547 |
548 | self.n_latent = self.log_size * 2 - 2
549 |
550 | def forward(
551 | self,
552 | styles,
553 | return_latents=False,
554 | inject_index=None,
555 | truncation=1,
556 | truncation_latent=None,
557 | input_is_latent=False,
558 | noise=None,
559 | w_plus=False,
560 | delta_w=None,
561 | ):
562 | if not input_is_latent:
563 | styles = [self.style(s) for s in styles]
564 |
565 | if noise is None:
566 | noise = [None] * (2 * (self.log_size - 2) + 1)
567 |
568 | if truncation < 1:
569 | style_t = []
570 | for style in styles:
571 | style_t.append(truncation_latent + truncation * (style - truncation_latent))
572 | styles = style_t
573 |
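    |         # A single style is broadcast to all n_latent layers (or taken as a full
    |         # W+ tensor when w_plus=True); two styles trigger standard style mixing
    |         # at inject_index.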
574 | if len(styles) < 2:
575 | if not w_plus:
576 | inject_index = self.n_latent
577 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
578 | else:
579 | latent = styles[0]
580 | assert latent.shape[1] == self.n_latent
581 | else:
582 | if inject_index is None:
583 | inject_index = random.randint(1, self.n_latent - 1)
584 |
585 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
586 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
587 |
588 | latent = torch.cat([latent, latent2], 1)
589 |
590 | out = self.input(latent)
591 | out = self.conv1(out, latent[:, 0], noise=noise[0])
592 | skip = self.to_rgb1(out, latent[:, 1])
593 | i = 1
594 | for idx, (conv1, conv2, noise1, noise2,
595 | to_rgb) in enumerate(zip(self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs)):
596 | out = conv1(out, latent[:, i], noise=noise1)
597 | out = conv2(out, latent[:, i + 1], noise=noise2)
598 | skip = to_rgb(out, latent[:, i + 2], skip)
599 | i += 2
600 | image = skip
601 | return image
602 |
603 | def make_noise(self):
604 | device = self.input.input.device
605 | noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
606 | for i in range(3, self.log_size + 1):
607 | for _ in range(2):
608 | noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
609 |
610 | return noises
611 |
612 | def mean_latent(self, n_latent):
613 | latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
614 | latent = self.style(latent_in).mean(0, keepdim=True)
615 | return latent
616 |
617 | def get_latent(self, input):
618 | return self.style(input)
619 |
620 |
621 | class FullGenerator(nn.Module):
622 |
623 | def __init__(
624 | self,
625 | size,
626 | style_dim,
627 | n_mlp,
628 | channel_multiplier=2,
629 | blur_kernel=[1, 3, 3, 1],
630 | lr_mlp=0.01,
631 | isconcat=True,
632 | narrow=1,
633 | device='cpu',
634 | mask_p='',
635 | mask_n_noise=None,
636 | face_z=False,
637 | tune_k=None,
638 | n_mlp_tune=2,
639 | noise_channel=False,
640 | noise_mask_p=None,
641 | ):
642 | super().__init__()
643 |
644 | self.size = size
645 | self.face_z = face_z
646 | self.mask_n_noise = mask_n_noise
647 | self.mask_p = mask_p
648 | self.noise_mask_p = noise_mask_p or self.mask_p
649 | self.act = nn.Sigmoid()
650 |
651 | self.audio_encoder = AudioEncoder(device=device)
652 |
653 | channels = {
654 | 4: int(512 * narrow),
655 | 8: int(512 * narrow),
656 | 16: int(512 * narrow),
657 | 32: int(512 * narrow),
658 | 64: int(256 * channel_multiplier * narrow),
659 | 128: int(128 * channel_multiplier * narrow),
660 | 256: int(64 * channel_multiplier * narrow),
661 | 512: int(32 * channel_multiplier * narrow),
662 | 1024: int(16 * channel_multiplier * narrow),
663 | 2048: int(8 * channel_multiplier * narrow)
664 | }
665 | self.channels = channels
666 |
667 | self.log_size = int(math.log(size, 2))
668 | self.generator = Generator(
669 | size,
670 | style_dim,
671 | n_mlp,
672 | channel_multiplier=channel_multiplier,
673 | blur_kernel=blur_kernel,
674 | lr_mlp=lr_mlp,
675 | isconcat=isconcat,
676 | narrow=narrow,
677 | device=device,
678 | )
679 |
680 | conv = [ConvLayer(6, channels[size], 1, device=device)]
681 | self.ecd0 = nn.Sequential(*conv)
682 | in_channel = channels[size]
683 |
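    |         # Encoder pyramid ecd0 .. ecd{log_size-2}: downsamples the 6-channel input
    |         # from `size` down to 4x4; its per-resolution activations are reused in
    |         # forward() as the generator's spatial noise inputs.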
684 | self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
685 | for i in range(self.log_size, 2, -1):
686 | out_channel = channels[2**(i - 1)]
687 | conv = [ConvLayer(in_channel, out_channel, 3, downsample=True, device=device)]
688 | setattr(self, self.names[self.log_size - i + 1], nn.Sequential(*conv))
689 | in_channel = out_channel
690 |
691 | if self.mask_n_noise:
692 | size = self.size
693 | mask_mouth_region = cv2.imread(self.noise_mask_p)
694 | mask_mouth_region = cv2.resize(mask_mouth_region, (size, size))
695 | mask_back_region = 1. - mask_mouth_region[:, :, 0] / 255.
696 | mask_back_region_torch = torch.from_numpy(mask_back_region).float().view(1, 1, size, size)
697 | self.mask_back_region_list = [mask_back_region_torch]
698 | if self.mask_n_noise > 1:
699 | for _ in range(1, self.mask_n_noise):
700 | size = size // 2
701 | mask_back_region = cv2.resize(mask_back_region, (size, size))
702 | mask_back_region_torch = torch.from_numpy(mask_back_region).float().view(1, 1, size, size)
703 | self.mask_back_region_list.append(mask_back_region_torch)
704 |
705 | if self.face_z:
706 | self.final_linear = nn.Sequential(EqualLinear(channels[4] * 4 * 4, style_dim, activation='fused_lrelu', device=device))
707 | self.cat_linear = nn.Sequential(EqualLinear(style_dim * 2, style_dim, activation='fused_lrelu', device=device))
708 |
709 | # if tune_k:
710 | # for k in tune_k:
711 | # _logger.info('Creating modules finetuned in [{}] ...'.format(k))
712 | # self.finetune(k, freeze_other=False, n_mlp_tune=n_mlp_tune, noise_channel=noise_channel, device=device)
713 |
714 | def forward(
715 | self,
716 | face_sequences,
717 | audio_sequences,
718 | return_latents=False,
719 | inject_index=None,
720 | truncation=1,
721 | truncation_latent=None,
722 | input_is_latent=False,
723 | inversion_init=False,
724 | sm_audio_w=0.3,
725 | a_alpha=1.25,
726 | ):
727 | face = self.tensor5to4(face_sequences)
728 | audio = self.tensor5to4_audio(audio_sequences)
729 | inputs_masked = face[:, :3]
730 | inputs_ref = face[:, 3:]
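    |         # face stacks the masked target frame (first 3 channels) and the reference
    |         # frame (last 3 channels); the full 6-channel stack is fed to the encoder.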
731 |
732 | audio_feat = self.audio_encoder(audio)
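    |         # Exponential smoothing of the audio style across consecutive forward calls
    |         # (weight sm_audio_w on the previous value, state kept on the module as
    |         # sm_audio_feat), followed by an optional amplification by a_alpha.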
733 | if sm_audio_w > 0:
734 | sm_audio_feat = getattr(self, 'sm_audio_feat', None)
735 | if sm_audio_feat is None:
736 | sm_audio_feat = audio_feat
737 | sm_audio_feat = sm_audio_w * sm_audio_feat + (1 - sm_audio_w) * audio_feat
738 | audio_feat = sm_audio_feat
739 | setattr(self, 'sm_audio_feat', sm_audio_feat)
740 | if a_alpha > 0:
741 | audio_feat *= a_alpha
742 | outs = audio_feat
743 |
744 | noise = []
745 | inputs = face
746 | for i in range(self.log_size - 1):
747 | ecd = getattr(self, self.names[i])
748 | inputs = ecd(inputs)
749 | noise.append(inputs)
750 | face_feat_final = inputs
751 |
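    |         # Optionally zero out the mouth region in the first mask_n_noise
    |         # (highest-resolution) skip features, so that region is synthesized from
    |         # the audio style rather than copied through the skip connections.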
752 | if self.mask_n_noise:
753 | for j in range(self.mask_n_noise):
754 | noise_local = noise[j]
755 | mask_local = self.mask_back_region_list[j].type_as(noise_local).to(noise_local.device)
756 | noise[j] = noise_local * mask_local
757 | repeat_noise = list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in noise))[::-1]
758 |
759 | if self.face_z:
760 | face_feat = self.final_linear(face_feat_final.view(face_feat_final.shape[0], -1))
761 | outs = self.cat_linear(torch.cat([outs, face_feat], dim=1))
762 |
763 | outs = self.generator(
764 | [outs],
765 | False,
766 | inject_index,
767 | truncation,
768 | truncation_latent,
769 | input_is_latent,
770 | noise=repeat_noise[1:],
771 | )
772 | image = self.act(outs)
773 | return image
774 |
775 | def tensor5to4(self, input):
776 | input_dim_size = len(input.size())
777 | if input_dim_size > 4:
778 | b, c, t, h, w = input.size()
779 | input = input.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
780 | return input
781 |
782 | def tensor5to4_audio(self, input):
783 | input_dim_size = len(input.size())
784 | if input_dim_size > 4:
785 | b, t, c, h, w = input.size()
786 | input = input.reshape(-1, c, h, w)
787 | return input
788 |
789 |
790 | if __name__ == '__main__':
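    |     # Smoke test: 2 clips x 5 frames of 6-channel (masked + reference) 256x256
    |     # faces with matching 1x80x16 mel windows; expected output (10, 3, 256, 256).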
791 | model = FullGenerator(256, 512, 8)
792 | img_batch = torch.ones(2, 6, 5, 256, 256)
793 | mel_batch = torch.ones(2, 5, 1, 80, 16)
794 | x = model(img_batch, mel_batch)
795 | print(x.shape)
796 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import torch
5 | import tempfile
6 | import subprocess
7 | import numpy as np
8 | import os.path as osp
9 | from munch import munchify
10 |
11 | import audio
12 | from stylesync_model import FullGenerator
13 |
14 |
15 | def requires_grad(model, flag=True):
16 | for p in model.parameters():
17 | p.requires_grad = flag
18 |
19 |
20 | def load_model(path):
21 | with open(osp.join(osp.dirname(path), 'args.json')) as f:
22 | model_args = munchify(json.load(f))
23 | assert not model_args.gpen_norm
24 | generator = create_generator(model_args)
25 | print("Load checkpoint from: {}".format(path))
26 | checkpoint = torch.load(path, map_location='cpu')
27 | s = checkpoint["state_dict"]
28 | new_s = {}
29 | for k, v in s.items():
30 | new_s[k.replace('module.', '')] = v
31 | model = generator
32 | model.load_state_dict(new_s)
33 | return model
34 |
35 |
36 | def create_generator(args):
37 | generator = FullGenerator(args.size,
38 | args.latent,
39 | args.n_mlp,
40 | channel_multiplier=args.channel_multiplier,
41 | narrow=args.narrow,
42 | device=args.device,
43 | mask_p=args.mask_p,
44 | mask_n_noise=args.mask_n_noise,
45 | face_z=getattr(args, 'face_z', False),
46 | noise_mask_p=getattr(args, 'noise_mask_p', None)).to(args.device)
47 | return generator
48 |
49 |
50 | def read_wav(wav_path, fps=25, mel_step_size=16):
51 | temp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav")
52 | if not wav_path.endswith('.wav'):
53 | print('Extracting raw audio...')
54 | audio_name = temp_audio_file.name
55 | command = 'ffmpeg -i %s -loglevel error -y -f wav -acodec pcm_s16le -ar 16000 %s' % (wav_path, audio_name)
56 | subprocess.call(command, shell=True)
57 | wav_path = audio_name
58 |
59 | wav = audio.load_wav(wav_path, 16000)
60 | mel = audio.melspectrogram(wav)
61 | if np.isnan(mel.reshape(-1)).sum() > 0:
62 |         raise ValueError('Mel contains NaN! If you are using a TTS voice, add a small amount of noise to the wav file and try again')
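    |     # One 80 x mel_step_size mel window per video frame: the spectrogram is
    |     # assumed to have 80 mel frames per second of audio, so the start index
    |     # advances by 80 / fps per frame.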
63 | mel_chunks = []
64 | mel_idx_multiplier = 80. / fps
65 | i = 0
66 |     while True:
67 | start_idx = int(i * mel_idx_multiplier)
68 | if start_idx + mel_step_size > len(mel[0]):
69 | break
70 | mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
71 | i += 1
72 | temp_audio_file.close()
73 | return mel_chunks, mel
74 |
75 |
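    | # Similarity transform (scale, rotation, translation) mapping points1 onto
    | # points0, estimated via SVD (orthogonal Procrustes). With smooth=True the
    | # translation is nudged by an exponentially smoothed offset (0.2 previous +
    | # 0.8 current) derived from the third point pair of the normalized point sets.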
76 | def transformation_from_points(points1, points0, smooth=True, p_bias=None):
77 | points2 = np.array(points0)
78 | points2 = points2.astype(np.float64)
79 | points1 = points1.astype(np.float64)
80 | c1 = np.mean(points1, axis=0)
81 | c2 = np.mean(points2, axis=0)
82 | points1 -= c1
83 | points2 -= c2
84 | s1 = np.std(points1)
85 | s2 = np.std(points2)
86 | points1 /= s1
87 | points2 /= s2
88 | U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
89 | R = (np.matmul(U, Vt)).T
90 | sR = (s2 / s1) * R
91 | T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
92 | M = np.concatenate((sR, T), axis=1)
93 | if smooth:
94 | bias = points2[2] - points1[2]
95 | if p_bias is None:
96 | p_bias = bias
97 | else:
98 | bias = p_bias * 0.2 + bias * 0.8
99 | p_bias = bias
100 | M[:, 2] = M[:, 2] + bias
101 | return M, p_bias
102 |
103 |
104 | class AlignRestore(object):
105 |
106 | def __init__(self, align_points=3):
107 | if align_points == 3:
108 | self.upscale_factor = 1
109 | self.crop_ratio = (2.8, 2.8)
110 | self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
111 | self.face_template = self.face_template * 2.8
112 | self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
113 | self.p_bias = None
114 |
115 | def process(self, img, lmk_align=None, smooth=True, align_points=3):
116 | aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
117 | restored_img = self.restore_img(img, aligned_face, affine_matrix)
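    |         # Debug dumps of the intermediate alignment results (written on every call).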
118 | cv2.imwrite("restored.jpg", restored_img)
119 | cv2.imwrite("aligned.jpg", aligned_face)
120 | return aligned_face, restored_img
121 |
122 | def align_warp_face(self, img, lmks3, smooth=True, border_mode='constant'):
123 | affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
124 | if border_mode == 'constant':
125 | border_mode = cv2.BORDER_CONSTANT
126 | elif border_mode == 'reflect101':
127 | border_mode = cv2.BORDER_REFLECT101
128 | elif border_mode == 'reflect':
129 | border_mode = cv2.BORDER_REFLECT
130 | cropped_face = cv2.warpAffine(img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127])
131 | return cropped_face, affine_matrix
132 |
133 | def align_warp_face2(self, img, landmark, border_mode='constant'):
134 | affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
135 | if border_mode == 'constant':
136 | border_mode = cv2.BORDER_CONSTANT
137 | elif border_mode == 'reflect101':
138 | border_mode = cv2.BORDER_REFLECT101
139 | elif border_mode == 'reflect':
140 | border_mode = cv2.BORDER_REFLECT
141 | cropped_face = cv2.warpAffine(img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))
142 | return cropped_face, affine_matrix
143 |
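    |     # Warp the generated face back into the original frame with the inverse
    |     # affine transform, then blend it in with an eroded, Gaussian-blurred soft
    |     # mask so the paste-back seam is feathered rather than hard.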
144 | def restore_img(self, input_img, face, affine_matrix):
145 |         h, w, _ = input_img.shape
146 | h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
147 | upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
148 | inverse_affine = cv2.invertAffineTransform(affine_matrix)
149 | inverse_affine *= self.upscale_factor
150 | if self.upscale_factor > 1:
151 | extra_offset = 0.5 * self.upscale_factor
152 | else:
153 | extra_offset = 0
154 | inverse_affine[:, 2] += extra_offset
155 | inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
156 | mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
157 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
158 | inv_mask_erosion = cv2.erode(inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
159 | pasted_face = inv_mask_erosion[:, :, None] * inv_restored
160 | total_face_area = np.sum(inv_mask_erosion)
161 | w_edge = int(total_face_area**0.5) // 20
162 | erosion_radius = w_edge * 2
163 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
164 | blur_size = w_edge * 2
165 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
166 | inv_soft_mask = inv_soft_mask[:, :, None]
167 | upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
168 | if np.max(upsample_img) > 256:
169 | upsample_img = upsample_img.astype(np.uint16)
170 | else:
171 | upsample_img = upsample_img.astype(np.uint8)
172 | return upsample_img
173 |
174 |
175 | class laplacianSmooth(object):
176 |
177 | def __init__(self, smoothAlpha=0.3):
178 | self.smoothAlpha = smoothAlpha
179 | self.pts_last = None
180 |
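    |     # Per-landmark temporal smoothing: the weight on the previous position,
    |     # w = exp(-d^2 / (width * smoothAlpha)), decays with the squared displacement
    |     # relative to the face width, so jitter is damped while larger motions pass
    |     # through almost unchanged.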
181 | def smooth(self, pts_cur):
182 | if self.pts_last is None:
183 | self.pts_last = pts_cur.copy()
184 | return pts_cur.copy()
185 | x1 = min(pts_cur[:, 0])
186 | x2 = max(pts_cur[:, 0])
187 | y1 = min(pts_cur[:, 1])
188 | y2 = max(pts_cur[:, 1])
189 | width = x2 - x1
190 | pts_update = []
191 | for i in range(len(pts_cur)):
192 | x_new, y_new = pts_cur[i]
193 | x_old, y_old = self.pts_last[i]
194 | tmp = (x_new - x_old)**2 + (y_new - y_old)**2
195 | w = np.exp(-tmp / (width * self.smoothAlpha))
196 | x = x_old * w + x_new * (1 - w)
197 | y = y_old * w + y_new * (1 - w)
198 | pts_update.append([x, y])
199 | pts_update = np.array(pts_update)
200 | self.pts_last = pts_update.copy()
201 |
202 | return pts_update
--------------------------------------------------------------------------------
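
A minimal usage sketch (not the repository's `inference.py`) showing how the utilities above fit together, assuming a checkpoint directory that contains both the weights and the `args.json` read by `load_model`, a 256-resolution model, and `args.device == 'cpu'`; the checkpoint and audio paths are placeholders.

```python
import torch

from utils import load_model, read_wav, requires_grad

# Hypothetical paths; load_model also reads args.json from the checkpoint's directory.
ckpt_path = 'ckpts/stylesync.pth'
audio_path = 'demo/audio.wav'

model = load_model(ckpt_path).eval()
requires_grad(model, False)  # inference only

# One 80 x 16 mel window per video frame at 25 fps.
mel_chunks, _ = read_wav(audio_path, fps=25, mel_step_size=16)
print('frames to generate:', len(mel_chunks))

# Dummy face input with the layout expected by FullGenerator.forward
# (B, 6, T, H, W): masked target frame + reference frame, here all zeros.
with torch.no_grad():
    img_batch = torch.zeros(1, 6, 1, 256, 256)
    mel_batch = torch.from_numpy(mel_chunks[0].copy()).float().reshape(1, 1, 1, 80, 16)
    out = model(img_batch, mel_batch)
print(out.shape)  # (1, 3, 256, 256) for a 256-resolution model
```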